summaryrefslogtreecommitdiff
path: root/scripts/migrate_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/migrate_db.py')
-rw-r--r--scripts/migrate_db.py80
1 files changed, 80 insertions, 0 deletions
diff --git a/scripts/migrate_db.py b/scripts/migrate_db.py
new file mode 100644
index 0000000..b389f60
--- /dev/null
+++ b/scripts/migrate_db.py
@@ -0,0 +1,80 @@
+import argparse
+import asyncio
+import traceback
+
+import asyncpg
+import asyncpg_trek as pgtrek
+import dotenv
+from asyncpg_trek.asyncpg import AsyncpgBackend
+
+env_cfg: dict[str, str | None] = {}
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description="Migrates the database from the current revision to a new specified one.",
+ )
+ parser.add_argument(
+ "revision",
+ default="NEWEST",
+ help="the revision to upgrade/downgrade to. defaults to the newest one via the constant 'NEWEST'",
+ )
+ parser.add_argument(
+ "-f",
+ "--folder",
+ default="migrations",
+ help="the folder where migrations are stored",
+ )
+ parser.add_argument(
+ "-e",
+ "--env",
+ default=".env",
+ help="the .env file for the database. this is used to log into the database",
+ )
+ parser.add_argument(
+ "--dry-run",
+ action="store_true",
+ help="whether to perform a dry run of this script",
+ )
+ parser.add_argument(
+ "-d",
+ "--direction",
+ default="up",
+ help="which direction the migration from the current to the specified goes in. defaults to up",
+ )
+ return parser.parse_args()
+
+
+async def main(args: argparse.Namespace):
+ if (db_url := env_cfg.get("DATABASE")) is None:
+ raise RuntimeError()
+
+ dir: pgtrek.Direction
+ if args.direction.lower() == "down":
+ dir = pgtrek.Direction.down
+ else:
+ dir = pgtrek.Direction.up
+
+ conn = await asyncpg.connect(db_url)
+ backend = AsyncpgBackend(conn)
+
+ try:
+ plans = await pgtrek.plan(
+ backend,
+ args.folder,
+ target_revision=args.revision,
+ direction=dir,
+ )
+ await pgtrek.execute(backend, plans)
+ except Exception as e:
+ print("oh nyo, something went wrong!")
+ traceback.print_exception(e)
+ return
+
+ print("all done! :D")
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ env_cfg = dotenv.dotenv_values(args.env)
+ asyncio.run(main(args))