Skip to content

Commit

Permalink
Fixed clippy and formatting of rust services
Browse files Browse the repository at this point in the history
  • Loading branch information
joshniemela committed Aug 18, 2024
1 parent d8a9301 commit 857a5be
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ jobs:
service: [rust_parser, vector_store]
steps:
- uses: actions/checkout@v4
- name: Run clippy
- name: Run clippy on ${{ matrix.service }}
working-directory: backend/${{ matrix.service }}
run: cargo clippy --all-targets --all-features
4 changes: 2 additions & 2 deletions backend/rust_parser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ enum Department {
DrugDesignAndPharmacology,
CellularAndMolecularMedicine,
Pharmacy,
GLOBE,
Globe,
}
impl Department {
fn from_str(s: &str) -> Result<Self> {
Expand Down Expand Up @@ -124,7 +124,7 @@ impl Department {
"Institut for Nordiske Studier og Sprogvidenskab" => {
bail!("Nordic studies not supported <EXPECTED>")
}
"GLOBE Institute" => Ok(Department::GLOBE),
"GLOBE Institute" => Ok(Department::Globe),
_ => bail!("Unknown department: {}", s),
}
}
Expand Down
4 changes: 2 additions & 2 deletions backend/rust_parser/src/parser/course_information.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub fn parse_course_info(dom: &VDom) -> Result<CourseInformation> {
let panel_bodies = dom.get_elements_by_class_name("panel-body");
// there might be multiple panel-bodies, so we need to check each one
// for the dl element (only the course info should have a dl element)
for (_i, panel_body) in panel_bodies.enumerate() {
for panel_body in panel_bodies {
let mut dl_elements = panel_body
.get(parser)
.context("Failed to get panel-body")?
Expand Down Expand Up @@ -380,7 +380,7 @@ fn parse_dl(dl_tag: &tl::HTMLTag, parser: &tl::Parser) -> Result<Vec<(String, St
// for even numbers, we expect a dt element, odd numbers we expect a dd element
// make a pair of precisely two strings
let mut pair: Vec<String> = Vec::with_capacity(2);
for (_i, child) in children.top().iter().enumerate() {
for child in children.top().iter() {
let node = child
.get(parser)
.context("Failed to get node whilst parsing dl")?;
Expand Down
10 changes: 8 additions & 2 deletions backend/rust_parser/src/parser/exam_information.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,16 @@ fn parse_text_to_exam(text: &str) -> Result<Exam> {
_ if exam_name.contains("aflevering") || exam_name.contains("assignment") => {
Ok(Exam::Assignment(exam_minutes))
}
_ if exam_name.contains("skriftlig prøve") || exam_name.contains("skriftlig stedprøve") || exam_name.contains("written exam") => {
_ if exam_name.contains("skriftlig prøve")
|| exam_name.contains("skriftlig stedprøve")
|| exam_name.contains("written exam") =>
{
Ok(Exam::Written(exam_minutes))
}
_ if exam_name.contains("mundtlig prøve") || exam_name.contains("mundtligt forsvar") || exam_name.contains("oral exam") => {
_ if exam_name.contains("mundtlig prøve")
|| exam_name.contains("mundtligt forsvar")
|| exam_name.contains("oral exam") =>
{
Ok(Exam::Oral(exam_minutes))
}
_ if exam_name.contains("portfolio")
Expand Down
55 changes: 34 additions & 21 deletions backend/vector_store/src/db.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::{Coordinator, Course};
use crate::populate::Document;
use crate::embedding::{CoordinatorEmbedding, CourseEmbedding};
use crate::populate::Document;
use anyhow::Result;
use pgvector::Vector;
use sqlx::postgres::{PgPool, PgPoolOptions};
use sqlx::{Row, query};
use sqlx::{query, Row};

pub struct PostgresDB {
pub pool: PgPool,
Expand All @@ -30,7 +30,9 @@ impl PostgresDB {
WHERE
c.last_modified > COALESCE(te.last_modified, to_timestamp(0)) OR
c.last_modified > COALESCE(ce.last_modified, to_timestamp(0))"
).fetch_all(&self.pool).await?;
)
.fetch_all(&self.pool)
.await?;

let mut ids: Vec<String> = Vec::new();
for row in result {
Expand All @@ -47,8 +49,10 @@ impl PostgresDB {
"SELECT coordinator.email, coordinator.full_name
FROM coordinator
LEFT JOIN name_embedding ne ON coordinator.email = ne.email
WHERE ne.embedding IS NULL"
).fetch_all(&self.pool).await?;
WHERE ne.embedding IS NULL",
)
.fetch_all(&self.pool)
.await?;

let mut coordinators = Vec::new();
for row in result {
Expand Down Expand Up @@ -84,7 +88,6 @@ impl PostgresDB {
Ok(courses)
}


/// Inserts the document into the database.
/// If the document already exists, it updates the title, content, and last_modified timestamp.
/// This is used by populate.rs but is not strictly required.
Expand All @@ -96,15 +99,16 @@ impl PostgresDB {
let result = query!(
"SELECT title, content FROM course WHERE id = $1",
document.info.id
).fetch_optional(&self.pool).await?;
)
.fetch_optional(&self.pool)
.await?;

if let Some(row) = result {
if row.title == document.title && row.content == document.description.content {
return Ok(());
}
}


let mut tx = self.pool.begin().await?;

query!(
Expand Down Expand Up @@ -150,13 +154,18 @@ impl PostgresDB {
/// Inserts the coordinator embedding into the database.
/// If the coordinator already exists, it does nothing,
/// because we assume the names of the coordinators are immutable.
pub async fn insert_coordinator_embedding(&self, coordinator: CoordinatorEmbedding) -> Result<()> {
pub async fn insert_coordinator_embedding(
&self,
coordinator: CoordinatorEmbedding,
) -> Result<()> {
query(
"INSERT INTO name_embedding (email, embedding) VALUES ($1, $2)
ON CONFLICT(email) DO NOTHING")
.bind(coordinator.email)
.bind(Vector::from(coordinator.name.to_owned()))
.execute(&self.pool).await?;
ON CONFLICT(email) DO NOTHING",
)
.bind(coordinator.email)
.bind(Vector::from(coordinator.name.to_owned()))
.execute(&self.pool)
.await?;
Ok(())
}

Expand All @@ -166,17 +175,21 @@ impl PostgresDB {
let mut tx = self.pool.begin().await?;
query(
"INSERT INTO title_embedding (course_id, embedding) VALUES ($1, $2)
ON CONFLICT(course_id) DO UPDATE SET embedding = $2, last_modified = CURRENT_TIMESTAMP")
.bind(&course_embedding.id)
.bind(Vector::from(course_embedding.title.to_owned()))
.execute(&mut *tx).await?;
ON CONFLICT(course_id) DO UPDATE SET embedding = $2, last_modified = CURRENT_TIMESTAMP",
)
.bind(&course_embedding.id)
.bind(Vector::from(course_embedding.title.to_owned()))
.execute(&mut *tx)
.await?;

query(
"INSERT INTO content_embedding (course_id, embedding) VALUES ($1, $2)
ON CONFLICT(course_id) DO UPDATE SET embedding = $2, last_modified = CURRENT_TIMESTAMP")
.bind(course_embedding.id)
.bind(Vector::from(course_embedding.content.to_owned()))
.execute(&mut *tx).await?;
ON CONFLICT(course_id) DO UPDATE SET embedding = $2, last_modified = CURRENT_TIMESTAMP",
)
.bind(course_embedding.id)
.bind(Vector::from(course_embedding.content.to_owned()))
.execute(&mut *tx)
.await?;

tx.commit().await?;
Ok(())
Expand Down
65 changes: 37 additions & 28 deletions backend/vector_store/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use axum::extract::Query;
use axum::extract::State;
use axum::routing::get;
use axum::{Json, Router};
use futures_util::pin_mut;
use futures_util::stream::StreamExt;
use serde::Deserialize;
use sqlx::migrate;
use std::env;
use std::path::Path;
use std::sync::Arc;
use futures_util::pin_mut;
use futures_util::stream::StreamExt;
use sqlx::migrate;

mod db;
use db::PostgresDB;
Expand Down Expand Up @@ -57,7 +57,9 @@ struct SearchQuery {
/// These tasks use the embedder to generate the embeddings
#[tokio::main]
async fn main() {
let conn_string = env::var("POSTGRES_URL").expect("POSTGRES_URL not set, it should be in the format postgres://user:password@host/db");
let conn_string = env::var("POSTGRES_URL").expect(
"POSTGRES_URL not set, it should be in the format postgres://user:password@host/db",
);

let db = PostgresDB::new(&conn_string)
.await
Expand All @@ -70,19 +72,22 @@ async fn main() {
let data_dir = env::var("DATA_DIR").expect("DATA_DIR not set");
let new_json_dir = data_dir.to_owned() + "new_json/";
let path = Path::new(&new_json_dir);
upsert_documents_from_path(&db, path).await.expect("Failed to upsert documents from path into database");
upsert_documents_from_path(&db, path)
.await
.expect("Failed to upsert documents from path into database");

let state = AppState {
db: Arc::new(db),
embedder: Arc::new(Embedder::new())
embedder: Arc::new(Embedder::new()),
};

const SYNC_INTERVAL: u64 = 60 * 60 * 6;

let coordinator_state = state.clone();
tokio::spawn(async move {
loop {
populate_coordinator_embeddings(&coordinator_state.db, &coordinator_state.embedder).await;
populate_coordinator_embeddings(&coordinator_state.db, &coordinator_state.embedder)
.await;
println!("done populating coordinator embeddings");
tokio::time::sleep(tokio::time::Duration::from_secs(SYNC_INTERVAL)).await;
}
Expand All @@ -107,7 +112,9 @@ async fn main() {
.await
.expect("Failed to bind to port");
println!("listening on {}", port);
axum::serve(listener, app).await.expect("Failed to start server, this should not happen");
axum::serve(listener, app)
.await
.expect("Failed to start server, this should not happen");
}

/// Search endpoint that takes a query parameter and returns a list of the course ids that
Expand All @@ -118,52 +125,54 @@ async fn search(
) -> Json<Vec<String>> {
let query_embedding = state.embedder.embed_query(query.query);
let db = &state.db;
let ids = db.get_most_relevant_course_ids(&query_embedding)
let ids = db
.get_most_relevant_course_ids(&query_embedding)
.await
.expect("Failed to get most relevant course ids");
Json(ids)
}

/// Upserts the coordinator embeddings into the database using the coordinator information
/// from the database and the embedder to generate the embeddings
async fn populate_coordinator_embeddings(
db: &PostgresDB,
embedder: &Embedder,
) {
let missing_coordinators = db.get_missing_embedding_email_names().await.expect("Failed to get missing coordinators");
async fn populate_coordinator_embeddings(db: &PostgresDB, embedder: &Embedder) {
let missing_coordinators = db
.get_missing_embedding_email_names()
.await
.expect("Failed to get missing coordinators");

println!("missing coordinators: {}", missing_coordinators.len());

let embedding_stream = embedder.embed_coordinators(missing_coordinators);
pin_mut!(embedding_stream);

while let Some(embedded_coordinator) = embedding_stream.next().await {
db.insert_coordinator_embedding(
embedded_coordinator
).await.expect("Failed to insert coordinator embedding");
db.insert_coordinator_embedding(embedded_coordinator)
.await
.expect("Failed to insert coordinator embedding");
}
}

/// Upserts the course embeddings into the database using the course information
/// from the database and the embedder to generate the embeddings
async fn populate_course_embeddings(
db: &PostgresDB,
embedder: &Embedder,
) {
let outdated_embeddings = db.get_outdated_embedding_course_ids().await.expect("Failed to get outdated embeddings");
async fn populate_course_embeddings(db: &PostgresDB, embedder: &Embedder) {
let outdated_embeddings = db
.get_outdated_embedding_course_ids()
.await
.expect("Failed to get outdated embeddings");

let outdated_courses: Vec<Course> =
db.get_courses_by_ids(&outdated_embeddings).await.expect("Failed to get courses by ids");
let outdated_courses: Vec<Course> = db
.get_courses_by_ids(&outdated_embeddings)
.await
.expect("Failed to get courses by ids");

println!("missing documents: {}", outdated_courses.len());

let embedding_stream = embedder.embed_courses(outdated_courses);
pin_mut!(embedding_stream);

while let Some(embedded_document) = embedding_stream.next().await {

db.insert_course_embedding(
embedded_document
).await.expect("Failed to insert course embedding");
db.insert_course_embedding(embedded_document)
.await
.expect("Failed to insert course embedding");
}
}
4 changes: 2 additions & 2 deletions backend/vector_store/src/populate.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::{PostgresDB, Coordinator};
use super::{Coordinator, PostgresDB};
use anyhow::Result;
use nanohtml2text::html2text;
use serde::Deserialize;
use std::fs::File;
use std::io::BufReader;
use std::path::Path;
use serde::Deserialize;

#[derive(Deserialize, Clone)]
pub struct Document {
Expand Down

0 comments on commit 857a5be

Please sign in to comment.