diff --git a/decnet/web/api.py b/decnet/web/api.py index 8cb4418..c64e75a 100644 --- a/decnet/web/api.py +++ b/decnet/web/api.py @@ -34,6 +34,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: "username": "admin", "password_hash": get_password_hash("admin"), "role": "admin", + "must_change_password": True } ) yield @@ -76,6 +77,7 @@ async def get_current_user(token: str = Depends(oauth2_scheme)) -> str: class Token(BaseModel): access_token: str token_type: str + must_change_password: bool = False class LoginRequest(BaseModel): @@ -83,6 +85,11 @@ class LoginRequest(BaseModel): password: str +class ChangePasswordRequest(BaseModel): + old_password: str + new_password: str + + class LogsResponse(BaseModel): total: int limit: int @@ -91,7 +98,7 @@ class LogsResponse(BaseModel): @app.post("/api/v1/auth/login", response_model=Token) -async def login(request: LoginRequest) -> dict[str, str]: +async def login(request: LoginRequest) -> dict[str, Any]: user: dict[str, Any] | None = await repo.get_user_by_username(request.username) if not user or not verify_password(request.password, user["password_hash"]): raise HTTPException( @@ -105,7 +112,25 @@ async def login(request: LoginRequest) -> dict[str, str]: access_token: str = create_access_token( data={"uuid": user["uuid"]}, expires_delta=access_token_expires ) - return {"access_token": access_token, "token_type": "bearer"} + return { + "access_token": access_token, + "token_type": "bearer", + "must_change_password": bool(user.get("must_change_password", False)) + } + + +@app.post("/api/v1/auth/change-password") +async def change_password(request: ChangePasswordRequest, current_user: str = Depends(get_current_user)) -> dict[str, str]: + user: dict[str, Any] | None = await repo.get_user_by_uuid(current_user) + if not user or not verify_password(request.old_password, user["password_hash"]): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect old password", + ) + + new_hash = get_password_hash(request.new_password) + await repo.update_user_password(current_user, new_hash, must_change_password=False) + return {"message": "Password updated successfully"} @app.get("/api/v1/logs", response_model=LogsResponse) diff --git a/decnet/web/repository.py b/decnet/web/repository.py index c8db500..ec07f53 100644 --- a/decnet/web/repository.py +++ b/decnet/web/repository.py @@ -40,7 +40,17 @@ class BaseRepository(ABC): """Retrieve a user by their username.""" pass + @abstractmethod + async def get_user_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: + """Retrieve a user by their UUID.""" + pass + @abstractmethod async def create_user(self, user_data: dict[str, Any]) -> None: """Create a new dashboard user.""" pass + + @abstractmethod + async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None: + """Update a user's password and change the must_change_password flag.""" + pass diff --git a/decnet/web/sqlite_repository.py b/decnet/web/sqlite_repository.py index 7982993..aa2b968 100644 --- a/decnet/web/sqlite_repository.py +++ b/decnet/web/sqlite_repository.py @@ -29,9 +29,14 @@ class SQLiteRepository(BaseRepository): uuid TEXT PRIMARY KEY, username TEXT UNIQUE, password_hash TEXT, - role TEXT DEFAULT 'viewer' + role TEXT DEFAULT 'viewer', + must_change_password BOOLEAN DEFAULT 0 ) """) + try: + await db.execute("ALTER TABLE users ADD COLUMN must_change_password BOOLEAN DEFAULT 0") + except aiosqlite.OperationalError: + pass # Column already exists await db.commit() async def add_log(self, log_data: dict[str, Any]) -> None: @@ -112,15 +117,31 @@ class SQLiteRepository(BaseRepository): row = await cursor.fetchone() return dict(row) if row else None + async def get_user_by_uuid(self, uuid: str) -> Optional[dict[str, Any]]: + async with aiosqlite.connect(self.db_path) as db: + db.row_factory = aiosqlite.Row + async with db.execute("SELECT * FROM users WHERE uuid = ?", (uuid,)) as cursor: + row = await cursor.fetchone() + return dict(row) if row else None + async def create_user(self, user_data: dict[str, Any]) -> None: async with aiosqlite.connect(self.db_path) as db: await db.execute( - "INSERT INTO users (uuid, username, password_hash, role) VALUES (?, ?, ?, ?)", + "INSERT INTO users (uuid, username, password_hash, role, must_change_password) VALUES (?, ?, ?, ?, ?)", ( user_data["uuid"], user_data["username"], user_data["password_hash"], - user_data["role"] + user_data["role"], + user_data.get("must_change_password", False) ) ) await db.commit() + + async def update_user_password(self, uuid: str, password_hash: str, must_change_password: bool = False) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "UPDATE users SET password_hash = ?, must_change_password = ? WHERE uuid = ?", + (password_hash, must_change_password, uuid) + ) + await db.commit() diff --git a/tests/test_web_api.py b/tests/test_web_api.py index 4d81e9e..7e05596 100644 --- a/tests/test_web_api.py +++ b/tests/test_web_api.py @@ -32,6 +32,8 @@ def test_login_success() -> None: data = response.json() assert "access_token" in data assert data["token_type"] == "bearer" + assert "must_change_password" in data + assert data["must_change_password"] is True def test_login_failure() -> None: @@ -49,6 +51,38 @@ def test_login_failure() -> None: assert response.status_code == 401 +def test_change_password() -> None: + with TestClient(app) as client: + # First login to get token + login_resp = client.post("/api/v1/auth/login", json={"username": "admin", "password": "admin"}) + token = login_resp.json()["access_token"] + + # Try changing password with wrong old password + resp1 = client.post( + "/api/v1/auth/change-password", + json={"old_password": "wrong", "new_password": "new_secure_password"}, + headers={"Authorization": f"Bearer {token}"} + ) + assert resp1.status_code == 401 + + # Change password successfully + resp2 = client.post( + "/api/v1/auth/change-password", + json={"old_password": "admin", "new_password": "new_secure_password"}, + headers={"Authorization": f"Bearer {token}"} + ) + assert resp2.status_code == 200 + + # Verify old password no longer works + resp3 = client.post("/api/v1/auth/login", json={"username": "admin", "password": "admin"}) + assert resp3.status_code == 401 + + # Verify new password works and must_change_password is False + resp4 = client.post("/api/v1/auth/login", json={"username": "admin", "password": "new_secure_password"}) + assert resp4.status_code == 200 + assert resp4.json()["must_change_password"] is False + + def test_get_logs_unauthorized() -> None: with TestClient(app) as client: response = client.get("/api/v1/logs")