Files
PromptTech/backend/migrate_user_table.py

82 lines
2.7 KiB
Python
Raw Normal View History

"""
Database migration script to add email verification and OAuth fields to User table.
Run this script to update your existing database.
"""
import asyncio
from sqlalchemy import text
from database import AsyncSessionLocal
async def migrate_database():
    """Add email-verification and OAuth columns to the ``users`` table.

    All steps run in one session and are committed together at the end:

    1. Query ``information_schema.columns`` for the four new columns.
    2. If all four already exist, print a message and return without changes.
    3. Otherwise add each missing column with ``ADD COLUMN IF NOT EXISTS``.
    4. Drop the NOT NULL constraint on ``users.password`` so OAuth users can
       be stored without a password.
    5. If ``email_verified`` was created by this run, mark every existing user
       as verified, since they registered before verification existed.

    The DDL used (``ADD COLUMN IF NOT EXISTS``, ``ALTER COLUMN ... DROP NOT
    NULL``) is PostgreSQL syntax.

    Returns:
        None. Progress is reported with ``print``.
    """
    # (column name, column type/constraints), in the order they are added.
    new_columns = (
        ("email_verified", "BOOLEAN NOT NULL DEFAULT FALSE"),
        ("verification_token", "VARCHAR(500)"),
        ("oauth_provider", "VARCHAR(50)"),
        ("oauth_id", "VARCHAR(255)"),
    )

    async with AsyncSessionLocal() as session:
        print("Starting database migration...")

        # Find out which of the new columns are already present.
        # NOTE(review): this lookup is not filtered by table_schema, so a
        # ``users`` table in another schema could also match -- confirm the
        # database has only one such table.
        check_query = text("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name='users'
            AND column_name IN ('email_verified', 'verification_token', 'oauth_provider', 'oauth_id');
        """)
        result = await session.execute(check_query)
        existing_columns = {row[0] for row in result.fetchall()}
        missing_columns = [
            name for name, _ in new_columns if name not in existing_columns
        ]

        # Checking every column (not just email_verified) means a partially
        # applied migration is completed instead of being reported as done.
        if not missing_columns:
            print("✓ Columns already exist. Migration not needed.")
            return

        print("Adding new columns to users table...")
        for name, ddl in new_columns:
            if name in existing_columns:
                print(f"✓ {name} column already exists")
                continue
            # Names and types come from the fixed literals above, never from
            # user input; DDL identifiers cannot be bound as parameters.
            await session.execute(text(
                f"ALTER TABLE users ADD COLUMN IF NOT EXISTS {name} {ddl};"
            ))
            print(f"✓ Added {name} column")

        # Make password nullable for OAuth users
        await session.execute(text("""
            ALTER TABLE users
            ALTER COLUMN password DROP NOT NULL;
        """))
        print("✓ Made password column nullable (for OAuth users)")

        # Backfill only when email_verified was just created: every existing
        # row then holds the DEFAULT FALSE and predates verification. If the
        # column already existed, FALSE may belong to a genuinely unverified
        # user, so those rows must not be flipped to TRUE.
        if "email_verified" in missing_columns:
            await session.execute(text("""
                UPDATE users
                SET email_verified = TRUE
                WHERE email_verified = FALSE;
            """))
            print("✓ Marked existing users as verified")

        await session.commit()
        print("\n✅ Migration completed successfully!")
if __name__ == "__main__":
    # Print a framed title banner, then drive the async migration to completion.
    banner_rule = "=" * 60
    for banner_line in (banner_rule, "User Table Migration Script", banner_rule):
        print(banner_line)
    asyncio.run(migrate_database())