Skip to content

Commit

Permalink
Misc. updates to CAVE-related scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeStrout committed Dec 13, 2024
1 parent c70d282 commit ff52457
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 81 deletions.
110 changes: 79 additions & 31 deletions scripts/cave_db_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
This script provides a raw SQL interface to a CAVE database.
Use with caution!
"""
import readline
import struct
import sys
from binascii import unhexlify
from typing import Any

import psycopg2
from geoalchemy2.elements import WKBElement
from shapely.wkb import loads as load_wkb
from sqlalchemy import create_engine, inspect
Expand All @@ -19,7 +21,6 @@
# Database connection parameters
DB_USER = "postgres"
DB_PASS = "Abracadabra is the passphrase"
DB_NAME = "dacey_human_fovea"
DB_HOST = "127.0.0.1" # Local proxy address; run Cloud SQL Auth Proxy
DB_PORT = 5432 # Default PostgreSQL port

Expand Down Expand Up @@ -106,16 +107,90 @@ def print_results(result: Result, batch_size: int = 20) -> None:
print(f"Error executing query: {str(e)}")


def select_database():
conn = psycopg2.connect(
host="127.0.0.1",
port="5432",
database="postgres", # Default database that always exists
user="postgres",
password="Abracadabra is the passphrase",
)
cur = conn.cursor()
cur.execute("SELECT datname FROM pg_database")
databases = [db[0] for db in cur.fetchall()]
databases.sort()
num_to_db = {}
next_num = 1
print("Available databases:")
for db in databases:
print(f" {' ' if next_num < 10 else ''}{next_num}. {db}")
num_to_db[next_num] = db
next_num += 1
print()
while True:
try:
choice: int | str = input("Enter DB name or number: ")
except EOFError:
print()
return None
if choice in databases:
return choice
choice = int(choice)
if choice in num_to_db:
return num_to_db[choice]


def do_command(command, connection, pending_commit=False) -> Any:
"""
Execute the given SQL command, or one of our extra commands
(quit, commit, or rollback). Return the new value for
pending_commit, indicating changes are pending.
"""
if command.lower() in ("quit", "exit"):
sys.exit()
elif command.lower() == "commit":
connection.commit()
print("Changes commited.")
pending_commit = False
elif command.lower() == "rollback":
connection.rollback()
print("Changes rolled back.")
pending_commit = False
else:
# SQL (possibly multi-line) command
while not command.endswith(";"):
try:
command += " " + input("...>")
except EOFError:
command = ""
break
if not command:
return pending_commit
try:
result = connection.execute(sql(command))
print_results(result)
if not result.returns_rows and result.rowcount > 0:
pending_commit = True
# pylint: disable-next=broad-exception-caught
except Exception as e:
print(e)
return pending_commit


def main():

db_name = select_database()
if db_name is None:
return

# Create the connection URL
connection_url = URL.create(
drivername="postgresql+psycopg2",
username=DB_USER,
password=DB_PASS,
host=DB_HOST,
port=DB_PORT,
database=DB_NAME,
database=db_name,
)

# Create the engine
Expand All @@ -125,7 +200,7 @@ def main():

# Try to connect, just to be sure we can
try:
print(f"Connecting to {DB_NAME} at {DB_HOST}:{DB_PORT}...")
print(f"Connecting to {db_name} at {DB_HOST}:{DB_PORT}...")
with engine.connect() as connection:
print("Successfully connected.")
# pylint: disable-next=broad-exception-caught
Expand All @@ -151,34 +226,7 @@ def main():
except EOFError:
print("\nExiting.")
sys.exit()
if inp.lower() in ("quit", "exit"):
sys.exit()
elif inp.lower() == "commit":
connection.commit()
print("Changes commited.")
pending_commit = False
elif inp.lower() == "rollback":
connection.rollback()
print("Changes rolled back.")
pending_commit = False
else:
# SQL (possibly multi-line) command
while not inp.endswith(";"):
try:
inp += " " + input("...>")
except EOFError:
inp = ""
break
if not inp:
continue
try:
result = connection.execute(sql(inp))
print_results(result)
if not result.returns_rows and result.rowcount > 0:
pending_commit = True
# pylint: disable-next=broad-exception-caught
except Exception as e:
print(e)
pending_commit = do_command(inp, connection, pending_commit)


if __name__ == "__main__":
Expand Down
84 changes: 46 additions & 38 deletions scripts/precomp_lines_to_cave.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
This script takes lines from a precomputed annotations file,
and stuffs them into a synapses table in CAVE.
"""
import readline
import sys
from typing import Dict, List

Expand Down Expand Up @@ -142,6 +143,50 @@ def create_batch_query(batch_items: List[Dict]) -> str:
print(f"Committed {len(items)} new rows to {table_name}")


def load_annotations():
layer = None
items = None
print("Enter Neuroglancer state link or ID, or a GS file path:")
inp = input("> ")
if inp.startswith("gs:"):
layer = build_annotation_layer(inp, mode="read")
else:
verify_cave_auth()
state_id = inp.split("/")[-1] # in case full URL was given

assert client is not None
state = client.state.get_state_json(state_id)
print("Select annotation layer containing synapses to import:")
anno_layer_name = get_annotation_layer_name(state)
data = nglui.parser.get_layer(state, anno_layer_name)

if "annotations" in data:
items = data["annotations"]
elif "source" in data:
print("Precomputed annotation layer.")
layer = build_annotation_layer(data["source"], mode="read")
else:
print("Neither 'annotations' nor 'source' found in layer data. I'm stumped.")
sys.exit()
if items is None and layer is not None:
opt = ""
while opt not in ("A", "B"):
opt = input("Read [A]ll lines, or only within some [B]ounds? ").upper()
if opt == "B":
bbox_start = input_vec3Di(" Bounds start")
bbox_end = input_vec3Di(" Bounds end")
resolution = input_vec3Di(" Resolution")
index = VolumetricIndex.from_coords(bbox_start, bbox_end, resolution)
lines = layer.read_in_bounds(index, strict=True)
else:
lines = layer.read_all()
items = [
{"id": hex(l.id)[2:], "type": "line", "pointA": l.start, "pointB": l.end}
for l in lines
]
return items


def main():
# Create the connection URL
connection_url = URL.create(
Expand Down Expand Up @@ -170,44 +215,7 @@ def main():
print(e)
sys.exit()

# Connect to CAVE so we can get the Neurglancer state
verify_cave_auth()

# Stored state URL be like:
# https://spelunker.cave-explorer.org/#!middleauth+https://global.daf-apis.com/nglstate/api/v1/4542084277075968
# ID is the last part of this.
state_id = input("Neuroglancer state link or ID: ")
state_id = state_id.split("/")[-1] # in case full URL was given

assert client is not None
state = client.state.get_state_json(state_id)
print("Select annotation layer containing synapses to import:")
anno_layer_name = get_annotation_layer_name(state)
data = nglui.parser.get_layer(state, anno_layer_name)

if "annotations" in data:
items = data["annotations"]
elif "source" in data:
print("Precomputed annotation layer.")
layer = build_annotation_layer(data["source"], mode="read")
opt = ""
while opt not in ("A", "B"):
opt = input("Read [A]ll lines, or only within some [B]ounds? ").upper()
if opt == "B":
bbox_start = input_vec3Di(" Bounds start")
bbox_end = input_vec3Di(" Bounds end")
resolution = input_vec3Di(" Resolution")
index = VolumetricIndex.from_coords(bbox_start, bbox_end, resolution)
lines = layer.read_in_bounds(index, strict=True)
else:
lines = layer.read_all()
items = [
{"id": hex(l.id)[2:], "type": "line", "pointA": l.start, "pointB": l.end}
for l in lines
]
else:
print("Neither 'annotations' nor 'source' found in layer data. I'm stumped.")
sys.exit()
items = load_annotations()

print(f"{len(items)} annotations ready to export.")
table_name = input("Table name: ")
Expand Down
1 change: 1 addition & 0 deletions scripts/scale_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from one resolution to another.
"""

import readline
from collections import namedtuple
from pathlib import Path

Expand Down
20 changes: 8 additions & 12 deletions scripts/synapse_sv_lookup/cave_synapse_seg_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@
DB_HOST = "127.0.0.1" # Local proxy address; run Cloud SQL Auth Proxy
DB_PORT = 5432 # Default PostgreSQL port


# Segmentation parameters
# pylint: disable-next=line-too-long
# pylint: disable=global-statement

# Globals
engine = None
Expand All @@ -41,7 +39,6 @@
synapse_resolution = Vec3D(0, 0, 0)
seg_layer = None
seg_bounds = None
seg_path = "" # e.g. gs://zetta_ws/dacey_human_fovea_2404

seg_resolution = Vec3D(0, 0, 0)
seg_chunk_size = Vec3D(0, 0, 0)
Expand Down Expand Up @@ -253,7 +250,7 @@ def input_vec3D(prompt="", default=None):
try:
x, y, z = map(float, s.replace(",", " ").split())
return Vec3D(x, y, z)
except:
except: # pylint: disable=bare-except
print("Enter x, y, and z values separated by commas or spaces.")


Expand Down Expand Up @@ -292,7 +289,7 @@ def lookup_segment_id(seg_point: Vec3D, load_data_if_needed: bool = False):
assert seg_layer[chunk_bounds] is not None # type: ignore[unreachable]
chunk = seg_layer[chunk_bounds][0]
# print('Chunk loaded.')
relative_point = floor(seg_point - chunk_bounds.start) # type: ignore[unreachable] # Geez mypy is stupid.
relative_point = floor(seg_point - chunk_bounds.start) # type: ignore[unreachable]
return chunk[relative_point[0], relative_point[1], relative_point[2]]


Expand Down Expand Up @@ -355,7 +352,7 @@ def main():
# Try to connect, just to be sure we can
try:
print(f"Connecting to {DB_NAME} at {DB_HOST}:{DB_PORT}...")
with engine.connect() as connection:
with engine.connect() as _:
print("Successfully connected to the database!")

# pylint: disable-next=broad-exception-caught
Expand All @@ -370,7 +367,8 @@ def main():
if len(supervox_table.split("__")) == 2:
break
print(
"This should look something like, for example: ipl_ribbon_synapses__dacey_human_fovea_2404"
"This should look something like, for example: "
"ipl_ribbon_synapses__dacey_human_fovea_2404"
)
synapse_table = supervox_table.split("__")[0]
synapse_resolution = input_vec3Di("Synapse voxel scale (resolution)")
Expand Down Expand Up @@ -398,17 +396,15 @@ def main():
where_clause = f"""
NOT EXISTS (SELECT 1 FROM {supervox_table} b WHERE b.id = {synapse_table}.id);
"""
where_clause = "" # HACK!!!
to_do = read_synapses(conn1, synapse_table, where_clause)

print(f"Read {len(to_do)} synapses from {synapse_table}")
print(f"Sorting synapses into chunks...")
print("Sorting synapses into chunks...")
bins = bin_synapses(to_do)
print(f"Divided them into {len(bins)} bins")
sorted_indexes = sorted(bins.keys())
for i in range(0, len(sorted_indexes)):
for i, bin_idx in enumerate(sorted_indexes):
print()
bin_idx = sorted_indexes[i]
print(f"STARTING BIN {i}/{len(sorted_indexes)} (bin {bin_idx})...")
process_synapses(bins[bin_idx])

Expand Down

0 comments on commit ff52457

Please sign in to comment.