diff --git a/categories.txt b/categories.txt new file mode 100644 index 0000000..c6b958f --- /dev/null +++ b/categories.txt @@ -0,0 +1,10 @@ +House payments +Utilities +Food +Transport +Health +Subscriptions +Insurance +Entertainment +Other +Unknown \ No newline at end of file diff --git a/config.py b/config.py index 318e69a..1c7ccd7 100644 --- a/config.py +++ b/config.py @@ -1,6 +1,10 @@ """ Generic config variables """ + from pathlib import Path PROJECT_FOLDER = Path(__file__).parent.resolve() + +OUTPUT_FOLDER = PROJECT_FOLDER / "output" +OUTPUT_FOLDER.mkdir(exist_ok=True) diff --git a/main.py b/main.py index 1e9393f..6d97bd3 100644 --- a/main.py +++ b/main.py @@ -1,62 +1,35 @@ -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns - -# Load datai -df = pd.read_csv("data/NL38INGB0001546874_01-01-2024_18-08-2024.csv", - sep=";") -df = df[["Date", "Name / Description", "Amount (EUR)", "Transaction type", "Notifications"]] - -# Rename -df = df.rename(columns={"Date": "date", - "Name / Description": "description", - "Amount (EUR)": "amount", - "Transaction type": "category", - "Notifications": "extra"}) - -# Set date format -df["date"] = pd.to_datetime(df["date"], format="%Y%m%d") -# To numeric -df["amount"] = df["amount"].str.replace(",", ".") -df["amount"] = pd.to_numeric(df["amount"]) - -# Get unique transaction categories -categories = df["category"].unique().tolist() -print(categories) - -# Split -plt.pie(df.groupby("category")["amount"].sum(), labels=categories, autopct='%.0f%%') -plt.show() - -# Categorize -""" -Online Banking = Moving between my accounts | Manual payments to other accounts -Batch payment = Retour. Only 1 instance -Cash machine = Cash withdrawal (ATM) -Deposit = Cash deposit (ATM) -Transfer = Rounding feature from savings | Retour -SEPA direct debit = Automatic debit payments -Various = ING account payments | Credit card repayment -Payment terminal = Payments at card machine -iDEAL = Payments via iDEAL - -____________________________________________________________________ - -Drop: - -- Batch payment -- Deposit -- Transfer - -Split: -- Online banking --> Remove if extra contains 'From Oranje spaarrekening' - -Keep: -- Online banking -- SEPA direct debit -- Various -- Payment terminal -- iDEAL -- Cash machine -""" +import json + +from config import PROJECT_FOLDER, OUTPUT_FOLDER +from main.load import ing_loader +from transformers import pipeline + +# Load and preprocess data +df = ing_loader( + PROJECT_FOLDER / "data" / "NL38INGB0001546874_01-01-2024_18-08-2024.csv" +) + +# Set up zero-shot classifier +zeroshot_classifier = pipeline( + "zero-shot-classification", + model="MoritzLaurer/deberta-v3-large-zeroshot-v2.0", + device=0, +) + +# Read budget categories from categories.txt +with open(PROJECT_FOLDER / "categories.txt", "r") as f: + budget_categories = [line.strip() for line in f.readlines()] + +# Set up json file to store transaction classifications +json_file_path = OUTPUT_FOLDER / "transaction_classifications.json" + +# Initialize an empty dictionary to store classifications +transaction_classifications = {} + +# Load existing classifications if the file exists +if json_file_path.exists(): + with json_file_path.open("r") as json_file: + transaction_classifications = json.load(json_file) + +# Iterate over each row ... diff --git a/main/load.py b/main/load.py new file mode 100644 index 0000000..6479668 --- /dev/null +++ b/main/load.py @@ -0,0 +1,73 @@ +""" +Contains loaders for different data sources (file only) + +Output will result in unified dataframe: +- date +- account_name +- account_number +- amount +- description +- category (if available) +""" + +import pandas as pd + + +def ing_loader(file_path): + """Loader for ING bank transaction export + + Online Banking = Moving between accounts | Manual payments to other accounts + Batch payment = Retour. Only 1 instance + Cash machine = Cash withdrawal (ATM) + Deposit = Cash deposit (ATM) + Transfer = Rounding feature from savings | Retour + SEPA direct debit = Automatic debit payments + Various = ING account payments | Credit card repayment + Payment terminal = Payments at card machine + iDEAL = Payments via iDEAL + ____________________________________________________________________ + Drop: + - Batch payment + - Deposit + - Transfer + + Split: + - Online banking --> Remove if extra contains 'From Oranje spaarrekening' + """ + # Load and preprocess data + df = pd.read_csv(file_path, sep=";") + df = df[ + [ + "Date", + "Name / Description", + "Account", + "Amount (EUR)", + "Transaction type", + "Notifications", + ] + ] + + # Rename columns for consistency and clarity + df = df.rename( + columns={ + "Date": "date", + "Name / Description": "account_name", + "Account": "account_number", + "Amount (EUR)": "amount", + "Transaction type": "category", + "Notifications": "description", + } + ) + + # Convert date and amount columns + df["date"] = pd.to_datetime(df["date"], format="%Y%m%d") + df["amount"] = pd.to_numeric(df["amount"].str.replace(",", ".")) + + # Filter transactions + df = df[~df["category"].isin(["Batch payment", "Deposit", "Transfer"])] + mask = (df["category"] == "Online Banking") & ( + df["extra"].str.contains("Oranje spaarrekening", case=False, na=False) + ) + df = df[~mask] + + return df diff --git a/main/main.ipynb b/main/main.ipynb deleted file mode 100644 index 29cd743..0000000 --- a/main/main.ipynb +++ /dev/null @@ -1,35 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "initial_id", - "metadata": { - "collapsed": true - }, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/poetry.lock b/poetry.lock index 124e824..ee3282e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -400,6 +400,23 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] +[[package]] +name = "jinja2" +version = "3.1.4" +description = "A very fast and expressive template engine." +optional = false +python-versions = ">=3.7" +files = [ + {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, + {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, +] + +[package.dependencies] +MarkupSafe = ">=2.0" + +[package.extras] +i18n = ["Babel (>=2.7)"] + [[package]] name = "kiwisolver" version = "1.4.5" @@ -513,6 +530,75 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] +[[package]] +name = "markupsafe" +version = "2.1.5" +description = "Safely add untrusted strings to HTML/XML markup." +optional = false +python-versions = ">=3.7" +files = [ + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a17a92de5231666cfbe003f0e4b9b3a7ae3afb1ec2845aadc2bacc93ff85febc"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:72b6be590cc35924b02c78ef34b467da4ba07e4e0f0454a2c5907f473fc50ce5"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2174c595a0d73a3080ca3257b40096db99799265e1c27cc5a610743acd86d62f"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ae2ad8ae6ebee9d2d94b17fb62763125f3f374c25618198f40cbb8b525411900"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:075202fa5b72c86ad32dc7d0b56024ebdbcf2048c0ba09f1cde31bfdd57bcfff"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:598e3276b64aff0e7b3451b72e94fa3c238d452e7ddcd893c3ab324717456bad"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fce659a462a1be54d2ffcacea5e3ba2d74daa74f30f5f143fe0c58636e355fdd"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win32.whl", hash = "sha256:d9fad5155d72433c921b782e58892377c44bd6252b5af2f67f16b194987338a4"}, + {file = "MarkupSafe-2.1.5-cp310-cp310-win_amd64.whl", hash = "sha256:bf50cd79a75d181c9181df03572cdce0fbb75cc353bc350712073108cba98de5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:629ddd2ca402ae6dbedfceeba9c46d5f7b2a61d9749597d4307f943ef198fc1f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b7b716f97b52c5a14bffdf688f971b2d5ef4029127f1ad7a513973cfd818df2"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ec585f69cec0aa07d945b20805be741395e28ac1627333b1c5b0105962ffced"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b91c037585eba9095565a3556f611e3cbfaa42ca1e865f7b8015fe5c7336d5a5"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7502934a33b54030eaf1194c21c692a534196063db72176b0c4028e140f8f32c"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0e397ac966fdf721b2c528cf028494e86172b4feba51d65f81ffd65c63798f3f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c061bb86a71b42465156a3ee7bd58c8c2ceacdbeb95d05a99893e08b8467359a"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3a57fdd7ce31c7ff06cdfbf31dafa96cc533c21e443d57f5b1ecc6cdc668ec7f"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win32.whl", hash = "sha256:397081c1a0bfb5124355710fe79478cdbeb39626492b15d399526ae53422b906"}, + {file = "MarkupSafe-2.1.5-cp311-cp311-win_amd64.whl", hash = "sha256:2b7c57a4dfc4f16f7142221afe5ba4e093e09e728ca65c51f5620c9aaeb9a617"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:8dec4936e9c3100156f8a2dc89c4b88d5c435175ff03413b443469c7c8c5f4d1"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3c6b973f22eb18a789b1460b4b91bf04ae3f0c4234a0a6aa6b0a92f6f7b951d4"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac07bad82163452a6884fe8fa0963fb98c2346ba78d779ec06bd7a6262132aee"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f5dfb42c4604dddc8e4305050aa6deb084540643ed5804d7455b5df8fe16f5e5"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ea3d8a3d18833cf4304cd2fc9cbb1efe188ca9b5efef2bdac7adc20594a0e46b"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d050b3361367a06d752db6ead6e7edeb0009be66bc3bae0ee9d97fb326badc2a"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:bec0a414d016ac1a18862a519e54b2fd0fc8bbfd6890376898a6c0891dd82e9f"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:58c98fee265677f63a4385256a6d7683ab1832f3ddd1e66fe948d5880c21a169"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win32.whl", hash = "sha256:8590b4ae07a35970728874632fed7bd57b26b0102df2d2b233b6d9d82f6c62ad"}, + {file = "MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c8b29db45f8fe46ad280a7294f5c3ec36dbac9491f2d1c17345be8e69cc5928f"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec6a563cff360b50eed26f13adc43e61bc0c04d94b8be985e6fb24b81f6dcfdf"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a549b9c31bec33820e885335b451286e2969a2d9e24879f83fe904a5ce59d70a"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4f11aa001c540f62c6166c7726f71f7573b52c68c31f014c25cc7901deea0b52"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:7b2e5a267c855eea6b4283940daa6e88a285f5f2a67f2220203786dfa59b37e9"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:2d2d793e36e230fd32babe143b04cec8a8b3eb8a3122d2aceb4a371e6b09b8df"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ce409136744f6521e39fd8e2a24c53fa18ad67aa5bc7c2cf83645cce5b5c4e50"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win32.whl", hash = "sha256:4096e9de5c6fdf43fb4f04c26fb114f61ef0bf2e5604b6ee3019d51b69e8c371"}, + {file = "MarkupSafe-2.1.5-cp37-cp37m-win_amd64.whl", hash = "sha256:4275d846e41ecefa46e2015117a9f491e57a71ddd59bbead77e904dc02b1bed2"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:656f7526c69fac7f600bd1f400991cc282b417d17539a1b228617081106feb4a"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:97cafb1f3cbcd3fd2b6fbfb99ae11cdb14deea0736fc2b0952ee177f2b813a46"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f3fbcb7ef1f16e48246f704ab79d79da8a46891e2da03f8783a5b6fa41a9532"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa9db3f79de01457b03d4f01b34cf91bc0048eb2c3846ff26f66687c2f6d16ab"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ffee1f21e5ef0d712f9033568f8344d5da8cc2869dbd08d87c84656e6a2d2f68"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5dedb4db619ba5a2787a94d877bc8ffc0566f92a01c0ef214865e54ecc9ee5e0"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:30b600cf0a7ac9234b2638fbc0fb6158ba5bdcdf46aeb631ead21248b9affbc4"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8dd717634f5a044f860435c1d8c16a270ddf0ef8588d4887037c5028b859b0c3"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win32.whl", hash = "sha256:daa4ee5a243f0f20d528d939d06670a298dd39b1ad5f8a72a4275124a7819eff"}, + {file = "MarkupSafe-2.1.5-cp38-cp38-win_amd64.whl", hash = "sha256:619bc166c4f2de5caa5a633b8b7326fbe98e0ccbfacabd87268a2b15ff73a029"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7a68b554d356a91cce1236aa7682dc01df0edba8d043fd1ce607c49dd3c1edcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:db0b55e0f3cc0be60c1f19efdde9a637c32740486004f20d1cff53c3c0ece4d2"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3e53af139f8579a6d5f7b76549125f0d94d7e630761a2111bc431fd820e163b8"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:17b950fccb810b3293638215058e432159d2b71005c74371d784862b7e4683f3"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c31f53cdae6ecfa91a77820e8b151dba54ab528ba65dfd235c80b086d68a465"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bff1b4290a66b490a2f4719358c0cdcd9bafb6b8f061e45c7a2460866bf50c2e"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bc1667f8b83f48511b94671e0e441401371dfd0f0a795c7daa4a3cd1dde55bea"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5049256f536511ee3f7e1b3f87d1d1209d327e818e6ae1365e8653d7e3abb6a6"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win32.whl", hash = "sha256:00e046b6dd71aa03a41079792f8473dc494d564611a8f89bbbd7cb93295ebdcf"}, + {file = "MarkupSafe-2.1.5-cp39-cp39-win_amd64.whl", hash = "sha256:fa173ec60341d6bb97a89f5ea19c85c5643c1e7dedebc22f5181eb73573142c5"}, + {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, +] + [[package]] name = "matplotlib" version = "3.9.2" @@ -576,6 +662,41 @@ python-dateutil = ">=2.7" [package.extras] dev = ["meson-python (>=0.13.1)", "numpy (>=1.25)", "pybind11 (>=2.6)", "setuptools (>=64)", "setuptools_scm (>=7)"] +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + +[[package]] +name = "networkx" +version = "3.3" +description = "Python package for creating and manipulating graphs and networks" +optional = false +python-versions = ">=3.10" +files = [ + {file = "networkx-3.3-py3-none-any.whl", hash = "sha256:28575580c6ebdaf4505b22c6256a2b9de86b316dc63ba9e93abde3d78dfdbcf2"}, + {file = "networkx-3.3.tar.gz", hash = "sha256:0c127d8b2f4865f59ae9cb8aafcd60b5c70f3241ebd66f7defad7c4ab90126c9"}, +] + +[package.extras] +default = ["matplotlib (>=3.6)", "numpy (>=1.23)", "pandas (>=1.4)", "scipy (>=1.9,!=1.11.0,!=1.11.1)"] +developer = ["changelist (==0.5)", "mypy (>=1.1)", "pre-commit (>=3.2)", "rtoml"] +doc = ["myst-nb (>=1.0)", "numpydoc (>=1.7)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.14)", "sphinx (>=7)", "sphinx-gallery (>=0.14)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=2.0)", "pygraphviz (>=1.12)", "sympy (>=1.10)"] +test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"] + [[package]] name = "nodeenv" version = "1.8.0" @@ -651,6 +772,150 @@ files = [ {file = "numpy-2.1.0.tar.gz", hash = "sha256:7dc90da0081f7e1da49ec4e398ede6a8e9cc4f5ebe5f9e06b443ed889ee9aaa2"}, ] +[[package]] +name = "nvidia-cublas-cu11" +version = "11.11.3.6" +description = "CUBLAS native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux1_x86_64.whl", hash = "sha256:39fb40e8f486dd8a2ddb8fdeefe1d5b28f5b99df01c87ab3676f057a74a5a6f3"}, + {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5ccae9e069a2c6be9af9cb5a0b0c6928c19c7915e390d15f598a1eead2a01a7a"}, + {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-manylinux2014_x86_64.whl", hash = "sha256:60252822adea5d0b10cd990a7dc7bedf7435f30ae40083c7a624a85a43225abc"}, + {file = "nvidia_cublas_cu11-11.11.3.6-py3-none-win_amd64.whl", hash = "sha256:6ab12b1302bef8ac1ff4414edd1c059e57f4833abef9151683fb8f4de25900be"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu11" +version = "11.8.87" +description = "CUDA profiling tools runtime libs." +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux1_x86_64.whl", hash = "sha256:0e50c707df56c75a2c0703dc6b886f3c97a22f37d6f63839f75b7418ba672a8d"}, + {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux2014_aarch64.whl", hash = "sha256:9aaa638963a8271df26b6ee0d93b781730b7acc6581ff700bd023d7934e4385e"}, + {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-manylinux2014_x86_64.whl", hash = "sha256:4191a17913a706b5098681280cd089cd7d8d3df209a6f5cb79384974a96d24f2"}, + {file = "nvidia_cuda_cupti_cu11-11.8.87-py3-none-win_amd64.whl", hash = "sha256:4332d8550ad5f5b673f98d08e4e4f82030cb604c66d8d5ee919399ea01312e58"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu11" +version = "11.8.89" +description = "NVRTC native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl", hash = "sha256:1f27d67b0f72902e9065ae568b4f6268dfe49ba3ed269c9a3da99bb86d1d2008"}, + {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux2014_aarch64.whl", hash = "sha256:8ab17ed51e7c4928f7170a0551e3e3b42f89d973bd267ced9688c238b3e10aef"}, + {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a8d02f3cba345be56b1ffc3e74d8f61f02bb758dd31b0f20e12277a5a244f756"}, + {file = "nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-win_amd64.whl", hash = "sha256:e18a23a8f4064664a6f1c4a64f38c581cbebfb5935280e94a4943ea8ae3791b1"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu11" +version = "11.8.89" +description = "CUDA Runtime native Libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux1_x86_64.whl", hash = "sha256:f587bd726eb2f7612cf77ce38a2c1e65cf23251ff49437f6161ce0d647f64f7c"}, + {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux2014_aarch64.whl", hash = "sha256:e53bf160b6b660819cb6e4a9da2cc89e6aa2329858299780a2459780a2b8d012"}, + {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-manylinux2014_x86_64.whl", hash = "sha256:92d04069a987e1fbc9213f8376d265df0f7bb42617d44f5eda1f496acea7f2d1"}, + {file = "nvidia_cuda_runtime_cu11-11.8.89-py3-none-win_amd64.whl", hash = "sha256:f60c9fdaed8065b38de8097867240efc5556a8a710007146daeb9082334a6e63"}, +] + +[[package]] +name = "nvidia-cudnn-cu11" +version = "9.1.0.70" +description = "cuDNN runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu11-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e6135ac63fe9d5b0b89cfb35c3fc1c1349f2b995becadf2e9dc21bca89d9633d"}, + {file = "nvidia_cudnn_cu11-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:32f6a2fe80b4b7ebc5f9c4cb403c4c381eca99e6daa3cf38241047b3d3e14daa"}, +] + +[package.dependencies] +nvidia-cublas-cu11 = "*" + +[[package]] +name = "nvidia-cufft-cu11" +version = "10.9.0.58" +description = "CUFFT native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl", hash = "sha256:222f9da70c80384632fd6035e4c3f16762d64ea7a843829cb278f98b3cb7dd81"}, + {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux2014_aarch64.whl", hash = "sha256:34b7315104e615b230dc3c2d1861f13bff9ec465c5d3b4bb65b4986d03a1d8d4"}, + {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e21037259995243cc370dd63c430d77ae9280bedb68d5b5a18226bfc92e5d748"}, + {file = "nvidia_cufft_cu11-10.9.0.58-py3-none-win_amd64.whl", hash = "sha256:c4d316f17c745ec9c728e30409612eaf77a8404c3733cdf6c9c1569634d1ca03"}, +] + +[[package]] +name = "nvidia-curand-cu11" +version = "10.3.0.86" +description = "CURAND native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu11-10.3.0.86-py3-none-manylinux1_x86_64.whl", hash = "sha256:ac439548c88580269a1eb6aeb602a5aed32f0dbb20809a31d9ed7d01d77f6bf5"}, + {file = "nvidia_curand_cu11-10.3.0.86-py3-none-manylinux2014_aarch64.whl", hash = "sha256:64defc3016d8c1de351a764617818c2961210430f12476faee10084b269b188c"}, + {file = "nvidia_curand_cu11-10.3.0.86-py3-none-manylinux2014_x86_64.whl", hash = "sha256:cd4cffbf78bb06580206b4814d5dc696d1161c902aae37b2bba00056832379e6"}, + {file = "nvidia_curand_cu11-10.3.0.86-py3-none-win_amd64.whl", hash = "sha256:8fa8365065fc3e3760d7437b08f164a6bcf8f7124f3b544d2463ded01e6bdc70"}, +] + +[[package]] +name = "nvidia-cusolver-cu11" +version = "11.4.1.48" +description = "CUDA solver native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux1_x86_64.whl", hash = "sha256:ca538f545645b7e6629140786d3127fe067b3d5a085bd794cde5bfe877c8926f"}, + {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1a96acb05768610bc414dbef5b25ebd2d820fc8a1e8c72097f41f53d80934d61"}, + {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea9fb1ad8c644ca9ed55af13cc39af3b7ba4c3eb5aef18471fe1fe77d94383cb"}, + {file = "nvidia_cusolver_cu11-11.4.1.48-py3-none-win_amd64.whl", hash = "sha256:7efe43b113495a64e2cf9a0b4365bd53b0a82afb2e2cf91e9f993c9ef5e69ee8"}, +] + +[package.dependencies] +nvidia-cublas-cu11 = "*" + +[[package]] +name = "nvidia-cusparse-cu11" +version = "11.7.5.86" +description = "CUSPARSE native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux1_x86_64.whl", hash = "sha256:4ae709fe78d3f23f60acaba8c54b8ad556cf16ca486e0cc1aa92dca7555d2d2b"}, + {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux2014_aarch64.whl", hash = "sha256:6c7da46abee7567e619d4aa2e90a1b032cfcbd1211d429853b1a6e87514a14b2"}, + {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8d7cf1628fd8d462b5d2ba6678fae34733a48ecb80495b9c68672ec6a6dde5ef"}, + {file = "nvidia_cusparse_cu11-11.7.5.86-py3-none-win_amd64.whl", hash = "sha256:a0f6ee81cd91be606fc2f55992d06b09cd4e86d74b6ae5e8dd1631cf7f5a8706"}, +] + +[[package]] +name = "nvidia-nccl-cu11" +version = "2.20.5" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu11-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:3619e25dfb0c8f4c554561c3459ee7dfe1250eed05e9aa4d147a75c45cc6ae0d"}, +] + +[[package]] +name = "nvidia-nvtx-cu11" +version = "11.8.86" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu11-11.8.86-py3-none-manylinux1_x86_64.whl", hash = "sha256:890656d8bd9b4e280231c832e1f0d03459200ba4824ddda3dcb59b1e1989b9f5"}, + {file = "nvidia_nvtx_cu11-11.8.86-py3-none-manylinux2014_aarch64.whl", hash = "sha256:5e84b97062eb102b45a8a9172a06cfe28b239b1635075a13d6474e91295e0468"}, + {file = "nvidia_nvtx_cu11-11.8.86-py3-none-manylinux2014_x86_64.whl", hash = "sha256:979f5b2aef5da164c5c53c64c85c3dfa61b8b4704f4f963bb568bf98fa8472e8"}, + {file = "nvidia_nvtx_cu11-11.8.86-py3-none-win_amd64.whl", hash = "sha256:54031010ee38d774b2991004d88f90bbd7bbc1458a96bbc4b42662756508c252"}, +] + [[package]] name = "packaging" version = "24.0" @@ -1250,6 +1515,23 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sympy" +version = "1.13.2" +description = "Computer algebra system (CAS) in Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "sympy-1.13.2-py3-none-any.whl", hash = "sha256:c51d75517712f1aed280d4ce58506a4a88d635d6b5dd48b39102a7ae1f3fcfe9"}, + {file = "sympy-1.13.2.tar.gz", hash = "sha256:401449d84d07be9d0c7a46a64bd54fe097667d5e7181bfe67ec777be9e01cb13"}, +] + +[package.dependencies] +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] + [[package]] name = "tokenizers" version = "0.19.1" @@ -1367,6 +1649,54 @@ dev = ["tokenizers[testing]"] docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] +[[package]] +name = "torch" +version = "2.4.1+cu118" +description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "torch-2.4.1+cu118-cp310-cp310-linux_x86_64.whl", hash = "sha256:740bae6eb10c6b41cb86c4f9e84da0b4533b5595aed4f06694d95d5e32b4076c"}, + {file = "torch-2.4.1+cu118-cp310-cp310-win_amd64.whl", hash = "sha256:08634e2d32e753ea4a086d42cc145f9251c767b0a7619fd9a6ed5c035dee7b63"}, + {file = "torch-2.4.1+cu118-cp311-cp311-linux_x86_64.whl", hash = "sha256:c7fbf1e214af65ccc0e54d265140b2d09486f977b966fcde218b25068bd54551"}, + {file = "torch-2.4.1+cu118-cp311-cp311-win_amd64.whl", hash = "sha256:1db7ac11687fc6ec279c5504302840a057a4f6bbadf81b7c4588fa53743cf493"}, + {file = "torch-2.4.1+cu118-cp312-cp312-linux_x86_64.whl", hash = "sha256:5c08fa312d259dd19dbd5058d3e82992b27f092347fccc7c0f417f9e0ac16956"}, + {file = "torch-2.4.1+cu118-cp312-cp312-win_amd64.whl", hash = "sha256:1ea70f54853744dd097992ef88d86d4ab793df0df4c41d1e37313ef4b80ef93f"}, + {file = "torch-2.4.1+cu118-cp38-cp38-linux_x86_64.whl", hash = "sha256:a8ba825b2c00274db8924b8c3e860a8d42947d2f04fcdf5b220ad7a650a83dea"}, + {file = "torch-2.4.1+cu118-cp38-cp38-win_amd64.whl", hash = "sha256:1520c0a9aa6d0187c9617b07409c9493d0bf20b28f26cffa3458995f53f58c48"}, + {file = "torch-2.4.1+cu118-cp39-cp39-linux_x86_64.whl", hash = "sha256:82309da0cce45cf61eb48b0567c7d080992d8ba98264da128ee1d858fac5dd75"}, + {file = "torch-2.4.1+cu118-cp39-cp39-win_amd64.whl", hash = "sha256:0eb4393f51f110e5d9f20e3e8079b4bd1fc8fa781c4ac521f709065f498be676"}, +] + +[package.dependencies] +filelock = "*" +fsspec = "*" +jinja2 = "*" +networkx = "*" +nvidia-cublas-cu11 = {version = "11.11.3.6", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu11 = {version = "11.8.87", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu11 = {version = "11.8.89", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu11 = {version = "11.8.89", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu11 = {version = "9.1.0.70", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu11 = {version = "10.9.0.58", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu11 = {version = "10.3.0.86", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu11 = {version = "11.4.1.48", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu11 = {version = "11.7.5.86", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu11 = {version = "2.20.5", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu11 = {version = "11.8.86", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +sympy = "*" +triton = {version = "3.0.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.13\""} +typing-extensions = ">=4.8.0" + +[package.extras] +opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.11.0)"] + +[package.source] +type = "legacy" +url = "https://download.pytorch.org/whl/cu118" +reference = "pytorch-gpu-src" + [[package]] name = "tqdm" version = "4.66.5" @@ -1455,6 +1785,28 @@ torchhub = ["filelock", "huggingface-hub (>=0.23.2,<1.0)", "importlib-metadata", video = ["av (==9.2.0)", "decord (==0.6.0)"] vision = ["Pillow (>=10.0.1,<=15.0)"] +[[package]] +name = "triton" +version = "3.0.0" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a"}, + {file = "triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c"}, + {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, + {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, + {file = "triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "llnl-hatchet", "numpy", "pytest", "scipy (>=1.7.1)"] +tutorials = ["matplotlib", "pandas", "tabulate"] + [[package]] name = "typing-extensions" version = "4.12.2" @@ -1517,4 +1869,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = ">=3.12,<3.13" -content-hash = "0be56f9b4a7e4221040062c33f0006f9f02ab8fe4493f62a643786e08746211a" +content-hash = "859b310d5bfe39c7388712a15827ce614c7bf29fdf985673043461211dd1584a" diff --git a/pyproject.toml b/pyproject.toml index 8e5dff2..726d3df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,14 +14,13 @@ matplotlib = "^3.9.2" seaborn = "^0.13.2" transformers = "^4.44.2" pandas = "^2.2.2" +torch = {version = "^2.4.0+cu121", source = "pytorch-gpu-src"} -[tool.black] -line-length = 100 -extend-exclude = '/^[^.]+$|\.(?!(py)$)([^.]+$)/' -[tool.isort] -profile = "black" -line_length = 100 +[[tool.poetry.source]] +name = "pytorch-gpu-src" +url = "https://download.pytorch.org/whl/cu118" +priority = "explicit" [build-system] requires = ["poetry-core"]