"""Migrate the database to a given revision using asyncpg-trek.

The database URL is read from the ``DATABASE`` key of a ``.env`` file. The
script plans the migrations between the current revision and the requested
one, then executes them unless ``--dry-run`` is given.
"""

import argparse
import asyncio
import sys
import traceback

import asyncpg
import asyncpg_trek as pgtrek
import dotenv
from asyncpg_trek.asyncpg import AsyncpgBackend

# Values loaded from the .env file named by ``--env``; populated in the
# ``__main__`` guard before ``main`` runs.
env_cfg: dict[str, str | None] = {}


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line arguments.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        Namespace with ``revision``, ``folder``, ``env``, ``dry_run`` and
        ``direction`` attributes. ``direction`` is lower-cased and is either
        ``"up"`` or ``"down"``.
    """
    parser = argparse.ArgumentParser(
        description="Migrates the database from the current revision to a new specified one.",
    )
    parser.add_argument(
        "revision",
        # Without nargs="?" argparse treats a positional as required and
        # never applies its default, so "NEWEST" would be unreachable.
        nargs="?",
        default="NEWEST",
        help="the revision to upgrade/downgrade to. defaults to the newest one via the constant 'NEWEST'",
    )
    parser.add_argument(
        "-f",
        "--folder",
        default="migrations",
        help="the folder where migrations are stored",
    )
    parser.add_argument(
        "-e",
        "--env",
        default=".env",
        help="the .env file for the database. this is used to log into the database",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="whether to perform a dry run of this script",
    )
    parser.add_argument(
        "-d",
        "--direction",
        default="up",
        # Normalise case first, then reject typos instead of silently
        # treating them as "up".
        type=str.lower,
        choices=("up", "down"),
        help="which direction the migration from the current to the specified goes in. defaults to up",
    )
    return parser.parse_args(argv)


async def main(args: argparse.Namespace) -> int:
    """Plan the migrations toward ``args.revision`` and execute them unless dry-running.

    Args:
        args: Parsed arguments, as returned by ``parse_args``.

    Returns:
        Process exit status: 0 on success (or after a dry run), 1 if planning
        or executing the migrations raised.

    Raises:
        RuntimeError: If ``DATABASE`` is missing from the loaded .env values.
    """
    if (db_url := env_cfg.get("DATABASE")) is None:
        raise RuntimeError(f"DATABASE is not set in env file {args.env!r}")

    direction = (
        pgtrek.Direction.down
        if args.direction.lower() == "down"
        else pgtrek.Direction.up
    )

    conn = await asyncpg.connect(db_url)
    try:
        backend = AsyncpgBackend(conn)
        try:
            plans = await pgtrek.plan(
                backend,
                args.folder,
                target_revision=args.revision,
                direction=direction,
            )
            # getattr keeps callers that build a Namespace by hand working.
            if getattr(args, "dry_run", False):
                print("dry run: not executing the planned migrations")
                # NOTE(review): relies on the plan object's repr being
                # readable; confirm against asyncpg-trek.
                print(plans)
                return 0
            await pgtrek.execute(backend, plans)
        except Exception as e:
            # Top-level boundary: report the failure and signal it through
            # the exit status rather than exiting 0.
            print("oh nyo, something went wrong!", file=sys.stderr)
            traceback.print_exception(e)
            return 1
    finally:
        # Always release the connection, whether migration succeeded or not.
        await conn.close()

    print("all done! :D")
    return 0


if __name__ == "__main__":
    args = parse_args()
    env_cfg = dotenv.dotenv_values(args.env)
    sys.exit(asyncio.run(main(args)))