feat: backend support for mandatory password change on first login
This commit is contained in:
@@ -34,6 +34,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
"username": "admin",
|
"username": "admin",
|
||||||
"password_hash": get_password_hash("admin"),
|
"password_hash": get_password_hash("admin"),
|
||||||
"role": "admin",
|
"role": "admin",
|
||||||
|
"must_change_password": True
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
yield
|
yield
|
||||||
@@ -76,6 +77,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str:
|
|||||||
class Token(BaseModel):
|
class Token(BaseModel):
|
||||||
access_token: str
|
access_token: str
|
||||||
token_type: str
|
token_type: str
|
||||||
|
must_change_password: bool = False
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
class LoginRequest(BaseModel):
|
||||||
@@ -83,6 +85,11 @@ class LoginRequest(BaseModel):
|
|||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChangePasswordRequest(BaseModel):
|
||||||
|
old_password: str
|
||||||
|
new_password: str
|
||||||
|
|
||||||
|
|
||||||
class LogsResponse(BaseModel):
|
class LogsResponse(BaseModel):
|
||||||
total: int
|
total: int
|
||||||
limit: int
|
limit: int
|
||||||
@@ -91,7 +98,7 @@ class LogsResponse(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/api/v1/auth/login", response_model=Token)
|
@app.post("/api/v1/auth/login", response_model=Token)
|
||||||
async def login(request: LoginRequest) -> dict[str, str]:
|
async def login(request: LoginRequest) -> dict[str, Any]:
|
||||||
user: dict[str, Any] | None = await repo.get_user_by_username(request.username)
|
user: dict[str, Any] | None = await repo.get_user_by_username(request.username)
|
||||||
if not user or not verify_password(request.password, user["password_hash"]):
|
if not user or not verify_password(request.password, user["password_hash"]):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -105,7 +112,25 @@ async def login(request: LoginRequest) -> dict[str, str]:
|
|||||||
access_token: str = create_access_token(
|
access_token: str = create_access_token(
|
||||||
data={"uuid": user["uuid"]}, expires_delta=access_token_expires
|
data={"uuid": user["uuid"]}, expires_delta=access_token_expires
|
||||||
)
|
)
|
||||||
return {"access_token": access_token, "token_type": "bearer"}
|
return {
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"must_change_password": bool(user.get("must_change_password", False))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/v1/auth/change-password")
|
||||||
|
async def change_password(request: ChangePasswordRequest, current_user: str = Depends(get_current_user)) -> dict[str, str]:
|
||||||
|
user: dict[str, Any] | None = await repo.get_user_by_uuid(current_user)
|
||||||
|
if not user or not verify_password(request.old_password, user["password_hash"]):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect old password",
|
||||||
|
)
|
||||||
|
|
||||||
|
new_hash = get_password_hash(request.new_password)
|
||||||
|
await repo.update_user_password(current_user, new_hash, must_change_password=False)
|
||||||
|
return {"message": "Password updated successfully"}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/v1/logs", response_model=LogsResponse)
|
@app.get("/api/v1/logs", response_model=LogsResponse)
|
||||||
|
|||||||
@@ -40,7 +40,17 @@ class BaseRepository(ABC):
|
|||||||
"""Retrieve a user by their username."""
|
"""Retrieve a user by their username."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_user_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||||
|
"""Retrieve a user by their UUID."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def create_user(self, user_data: dict[str, Any]) -> None:
|
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||||
"""Create a new dashboard user."""
|
"""Create a new dashboard user."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
|
||||||
|
"""Update a user's password and change the must_change_password flag."""
|
||||||
|
pass
|
||||||
|
|||||||
@@ -29,9 +29,14 @@ class SQLiteRepository(BaseRepository):
|
|||||||
uuid TEXT PRIMARY KEY,
|
uuid TEXT PRIMARY KEY,
|
||||||
username TEXT UNIQUE,
|
username TEXT UNIQUE,
|
||||||
password_hash TEXT,
|
password_hash TEXT,
|
||||||
role TEXT DEFAULT 'viewer'
|
role TEXT DEFAULT 'viewer',
|
||||||
|
must_change_password BOOLEAN DEFAULT 0
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
try:
|
||||||
|
await db.execute("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT 0")
|
||||||
|
except aiosqlite.OperationalError:
|
||||||
|
pass # Column already exists
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
async def add_log(self, log_data: dict[str, Any]) -> None:
|
async def add_log(self, log_data: dict[str, Any]) -> None:
|
||||||
@@ -112,15 +117,31 @@ class SQLiteRepository(BaseRepository):
|
|||||||
row = await cursor.fetchone()
|
row = await cursor.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
async def get_user_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]:
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
async with db.execute("SELECT * FROM users WHERE uuid = ?", (uuid,)) as cursor:
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
async def create_user(self, user_data: dict[str, Any]) -> None:
|
async def create_user(self, user_data: dict[str, Any]) -> None:
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
await db.execute(
|
await db.execute(
|
||||||
"INSERT INTO users (uuid, username, password_hash, role) VALUES (?, ?, ?, ?)",
|
"INSERT INTO users (uuid, username, password_hash, role, must_change_password) VALUES (?, ?, ?, ?, ?)",
|
||||||
(
|
(
|
||||||
user_data["uuid"],
|
user_data["uuid"],
|
||||||
user_data["username"],
|
user_data["username"],
|
||||||
user_data["password_hash"],
|
user_data["password_hash"],
|
||||||
user_data["role"]
|
user_data["role"],
|
||||||
|
user_data.get("must_change_password", False)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None:
|
||||||
|
async with aiosqlite.connect(self.db_path) as db:
|
||||||
|
await db.execute(
|
||||||
|
"UPDATE users SET password_hash = ?, must_change_password = ? WHERE uuid = ?",
|
||||||
|
(password_hash, must_change_password, uuid)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ def test_login_success() -> None:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
assert "access_token" in data
|
assert "access_token" in data
|
||||||
assert data["token_type"] == "bearer"
|
assert data["token_type"] == "bearer"
|
||||||
|
assert "must_change_password" in data
|
||||||
|
assert data["must_change_password"] is True
|
||||||
|
|
||||||
|
|
||||||
def test_login_failure() -> None:
|
def test_login_failure() -> None:
|
||||||
@@ -49,6 +51,38 @@ def test_login_failure() -> None:
|
|||||||
assert response.status_code == 401
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_change_password() -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# First login to get token
|
||||||
|
login_resp = client.post("/api/v1/auth/login", json={"username": "admin", "password": "admin"})
|
||||||
|
token = login_resp.json()["access_token"]
|
||||||
|
|
||||||
|
# Try changing password with wrong old password
|
||||||
|
resp1 = client.post(
|
||||||
|
"/api/v1/auth/change-password",
|
||||||
|
json={"old_password": "wrong", "new_password": "new_secure_password"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
assert resp1.status_code == 401
|
||||||
|
|
||||||
|
# Change password successfully
|
||||||
|
resp2 = client.post(
|
||||||
|
"/api/v1/auth/change-password",
|
||||||
|
json={"old_password": "admin", "new_password": "new_secure_password"},
|
||||||
|
headers={"Authorization": f"Bearer {token}"}
|
||||||
|
)
|
||||||
|
assert resp2.status_code == 200
|
||||||
|
|
||||||
|
# Verify old password no longer works
|
||||||
|
resp3 = client.post("/api/v1/auth/login", json={"username": "admin", "password": "admin"})
|
||||||
|
assert resp3.status_code == 401
|
||||||
|
|
||||||
|
# Verify new password works and must_change_password is False
|
||||||
|
resp4 = client.post("/api/v1/auth/login", json={"username": "admin", "password": "new_secure_password"})
|
||||||
|
assert resp4.status_code == 200
|
||||||
|
assert resp4.json()["must_change_password"] is False
|
||||||
|
|
||||||
|
|
||||||
def test_get_logs_unauthorized() -> None:
|
def test_get_logs_unauthorized() -> None:
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
response = client.get("/api/v1/logs")
|
response = client.get("/api/v1/logs")
|
||||||
|
|||||||
Reference in New Issue
Block a user