mirror of
https://github.com/stardrophere/InsightRadar.git
synced 2026-06-05 23:56:36 +08:00
修改只有一次验证机会的bug
This commit is contained in:
@@ -29,7 +29,12 @@ class MemoryRepository(VerificationRepository):
|
|||||||
with self._lock:
|
with self._lock:
|
||||||
self._store[key] = (json.dumps(payload), expire_at)
|
self._store[key] = (json.dumps(payload), expire_at)
|
||||||
|
|
||||||
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]:
|
def compare_and_consume(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
purpose: VerificationPurpose,
|
||||||
|
code_hash: str,
|
||||||
|
) -> Optional[bool]:
|
||||||
key = self._key(email, purpose)
|
key = self._key(email, purpose)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
@@ -44,14 +49,18 @@ class MemoryRepository(VerificationRepository):
|
|||||||
del self._store[key]
|
del self._store[key]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
del self._store[key]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
return payload.get("code_hash")
|
stored_hash = payload.get("code_hash")
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
if stored_hash == code_hash:
|
||||||
|
del self._store[key]
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
def incr(self, key: str, ttl: int) -> int:
|
def incr(self, key: str, ttl: int) -> int:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -2,83 +2,90 @@ from functools import lru_cache
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import datetime
|
|
||||||
import redis
|
import redis
|
||||||
from typing import Optional, TYPE_CHECKING
|
from typing import Optional
|
||||||
|
|
||||||
from app.models.models import VerificationPurpose
|
from app.models.models import VerificationPurpose
|
||||||
from app.core.verification.email.verificationRepository import VerificationRepository
|
from app.core.verification.email.verificationRepository import VerificationRepository
|
||||||
from app.utils.redis_client import get_redis_client
|
from app.utils.redis_client import get_redis_client
|
||||||
from app.core.security import hash_verification_code
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
AUTH_CODE_REDIS_PREFIX = os.getenv("AUTH_CODE_REDIS_PREFIX", "insightradar:auth_code").strip()
|
AUTH_CODE_REDIS_PREFIX = os.getenv(
|
||||||
|
"AUTH_CODE_REDIS_PREFIX", "insightradar:auth_code"
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
|
||||||
class RedisRepository(VerificationRepository):
|
class RedisRepository(VerificationRepository):
|
||||||
_consume_lua = """ local val = redis.call("GET", KEYS[1]) if val then redis.call("DEL", KEYS[1]) end return val """
|
|
||||||
|
|
||||||
|
_compare_and_consume_lua = """
|
||||||
|
local val = redis.call("GET", KEYS[1])
|
||||||
|
if not val then
|
||||||
|
return nil
|
||||||
|
end
|
||||||
|
|
||||||
|
local data = cjson.decode(val)
|
||||||
|
|
||||||
|
if data["code_hash"] == ARGV[1] then
|
||||||
|
redis.call("DEL", KEYS[1])
|
||||||
|
return 1
|
||||||
|
else
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, client: redis.Redis):
|
def __init__(self, client: redis.Redis):
|
||||||
self.client = client
|
self.client = client
|
||||||
self._consume_script = self.client.register_script(self._consume_lua)
|
self._compare_script = self.client.register_script(
|
||||||
|
self._compare_and_consume_lua
|
||||||
|
)
|
||||||
|
|
||||||
|
def _key(self, email: str, purpose: VerificationPurpose) -> str:
|
||||||
|
|
||||||
def _key(self, email, purpose):
|
|
||||||
return f"{AUTH_CODE_REDIS_PREFIX}:{purpose.value.lower()}:{email}:code"
|
return f"{AUTH_CODE_REDIS_PREFIX}:{purpose.value.lower()}:{email}:code"
|
||||||
|
|
||||||
def set_code(self, email: str, purpose: VerificationPurpose, code_hash: str, ttl: int) -> None:
|
def set_code(
|
||||||
"""store the code into the redis
|
self,
|
||||||
|
email: str,
|
||||||
Args:
|
purpose: VerificationPurpose,
|
||||||
email (str): email of user
|
code_hash: str,
|
||||||
purpose (VerificationPurpose): purpose of the code, such as "login", "register"
|
ttl: int,
|
||||||
code_hash: the hash of the code
|
) -> None:
|
||||||
ttl: duration of the code
|
|
||||||
|
|
||||||
"""
|
|
||||||
key = self._key(email, purpose)
|
key = self._key(email, purpose)
|
||||||
|
|
||||||
payload = json.dumps({
|
payload = json.dumps({
|
||||||
"code_hash": code_hash,
|
"code_hash": code_hash,
|
||||||
"exp": datetime.datetime.now().timestamp()
|
|
||||||
})
|
})
|
||||||
|
|
||||||
self.client.set(key, payload, ex=ttl)
|
self.client.set(key, payload, ex=ttl)
|
||||||
|
|
||||||
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]:
|
def compare_and_consume(
|
||||||
"""consume the code of email
|
self,
|
||||||
|
email: str,
|
||||||
Args:
|
purpose: VerificationPurpose,
|
||||||
email (str): email of user
|
code_hash: str,
|
||||||
purpose (VerificationPurpose): purpose of the code, such as "login", "register"
|
) -> Optional[bool]:
|
||||||
|
|
||||||
Returns:
|
|
||||||
_type_: if email has a code which has not been consumed, return the code, else return None
|
|
||||||
"""
|
|
||||||
|
|
||||||
key = self._key(email, purpose)
|
key = self._key(email, purpose)
|
||||||
data = self._consume_script(keys=[key])
|
|
||||||
|
|
||||||
if not data:
|
result = self._compare_script(
|
||||||
|
keys=[key],
|
||||||
|
args=[code_hash],
|
||||||
|
)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
if result == 1:
|
||||||
payload = json.loads(data) # type: ignore
|
return True
|
||||||
return payload.get("code_hash")
|
|
||||||
except Exception:
|
return False
|
||||||
return None
|
|
||||||
|
|
||||||
def incr(self, key: str, ttl: int) -> int:
|
def incr(self, key: str, ttl: int) -> int:
|
||||||
super().incr(key, ttl)
|
|
||||||
value = self.client.incr(key)
|
value = self.client.incr(key)
|
||||||
|
|
||||||
if value == 1:
|
if value == 1:
|
||||||
self.client.expire(key, ttl)
|
self.client.expire(key, ttl)
|
||||||
|
|
||||||
return value # type: ignore
|
return int(value) # type: ignore
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
def get_redis_repo():
|
def get_redis_repo():
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# app/verification/backends/hybrid.py
|
|
||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import logging
|
import logging
|
||||||
@@ -28,14 +27,23 @@ class HybridRepository(VerificationRepository):
|
|||||||
|
|
||||||
self.memory.set_code(email, purpose, code_hash, ttl)
|
self.memory.set_code(email, purpose, code_hash, ttl)
|
||||||
|
|
||||||
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]:
|
def compare_and_consume(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
purpose: VerificationPurpose,
|
||||||
|
code_hash: str,
|
||||||
|
) -> Optional[bool]:
|
||||||
if self.redis:
|
if self.redis:
|
||||||
try:
|
try:
|
||||||
return self.redis.consume_code(email, purpose)
|
return self.redis.compare_and_consume(
|
||||||
|
email, purpose, code_hash
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Redis consume_code failed, fallback to memory: %s", e)
|
logger.warning(
|
||||||
|
"Redis compare_and_consume failed, fallback to memory: %s", e
|
||||||
|
)
|
||||||
|
|
||||||
return self.memory.consume_code(email, purpose)
|
return self.memory.compare_and_consume(email, purpose, code_hash)
|
||||||
|
|
||||||
def incr(self, key: str, ttl: int) -> int:
|
def incr(self, key: str, ttl: int) -> int:
|
||||||
if self.redis:
|
if self.redis:
|
||||||
|
|||||||
@@ -20,7 +20,12 @@ class VerificationRepository(ABC):
|
|||||||
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def consume_code(self, email: str, purpose: VerificationPurpose) -> Optional[str]:
|
def compare_and_consume(
|
||||||
|
self,
|
||||||
|
email: str,
|
||||||
|
purpose: VerificationPurpose,
|
||||||
|
code_hash: str,
|
||||||
|
) -> Optional[bool]:
|
||||||
"""consume the code atomically
|
"""consume the code atomically
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -28,7 +33,7 @@ class VerificationRepository(ABC):
|
|||||||
purpose (VerificationPurpose): the purpose of the code, such as, "login", "register"
|
purpose (VerificationPurpose): the purpose of the code, such as, "login", "register"
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Optional[str]: if success return the code, else return None
|
Optional[str]: if success return the true
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -62,15 +62,16 @@ class EmailVerificationService:
|
|||||||
|
|
||||||
def verify_code(self,email: str, code: str, purpose: VerificationPurpose):
|
def verify_code(self,email: str, code: str, purpose: VerificationPurpose):
|
||||||
email = email.lower()
|
email = email.lower()
|
||||||
|
code_hash = hash_verification_code(code)
|
||||||
|
|
||||||
stored_hash: Optional[str] = self.repo.consume_code(email, purpose)
|
stored = self.repo.compare_and_consume(email, purpose, code_hash)
|
||||||
|
|
||||||
if not stored_hash:
|
if stored == False:
|
||||||
raise CodeExpiredError("Code expired or not found")
|
|
||||||
|
|
||||||
if stored_hash != hash_verification_code(code):
|
|
||||||
raise CodeInvalidError("Invalid code")
|
raise CodeInvalidError("Invalid code")
|
||||||
|
|
||||||
|
if not stored:
|
||||||
|
raise CodeExpiredError("Code expired or not found")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
|
|||||||
Reference in New Issue
Block a user