Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/optimise auth use cache #1455

Merged
merged 4 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions dongtai_common/common/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,22 +99,29 @@ def _noname(function):
return _noname


@cached_decorator(random_range=(60, 120), use_celery_update=False)
def get_user_from_department_key(key):
from dongtai_common.models.department import Department
from dongtai_common.models.user import User
from rest_framework import exceptions
department = Department.objects.get(token=key)
principal = User.objects.filter(pk=department.principal_id).first()
user = principal if principal else User.objects.filter(pk=1).first()
user.using_department = department
return user

class DepartmentTokenAuthentication(TokenAuthentication):

keyword = 'Token GROUP'
model = None

def authenticate_credentials(self, key):
def auth_decodedenticate_credentials(self, key):
from dongtai_common.models.department import Department
from dongtai_common.models.user import User
from rest_framework import exceptions
model = Department
try:
department = model.objects.get(token=key)
principal = User.objects.filter(pk=department.principal_id).first()
user = principal if principal else User.objects.filter(pk=1).first()
user.using_department = department
except model.DoesNotExist:
user = get_user_from_department_key(key)
except Department.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.'))
return (user, key)

Expand All @@ -125,4 +132,4 @@ def authenticate(self, request):
return None
token = auth.lower().replace(self.keyword.lower().encode(), b'',
1).decode()
return self.authenticate_credentials(token)
return self.auth_decodedenticate_credentials(token)
28 changes: 27 additions & 1 deletion dongtai_protocol/report/handler/report_handler_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@
from django.db.models import Q
from dongtai_common.models.agent import IastAgent
from django.utils.translation import gettext_lazy as _
from dongtai_common.common.utils import cached_decorator

logger = logging.getLogger('dongtai.openapi')


@cached_decorator(random_range=(60, 120), use_celery_update=False)
def get_agent(agent_id, kwargs, fields):
return IastAgent.objects.filter(
id=agent_id,
**kwargs,
).only(*fields).first()

class IReportHandler:
def __init__(self):
self._report = None
Expand Down Expand Up @@ -102,4 +110,22 @@ def get_project_agents(self, agent):
return agents

def get_agent(self, agent_id):
return IastAgent.objects.filter(id=agent_id, online=1, user=self.user_id).first()
return get_agent(
agent_id,
{
"pk": agent_id,
"online": 1,
"user": self.user_id,
},
(
'id',
'bind_project_id',
'project_version_id',
'project_name',
'language',
'project_version_id',
'server_id',
'filepathsimhash',
'servicetype',
),
)
127 changes: 73 additions & 54 deletions dongtai_protocol/report/handler/saas_method_pool_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from dongtai_engine.tasks import search_vul_from_method_pool, search_vul_from_replay_method_pool
from dongtai_conf import settings
from dongtai_protocol import utils
from dongtai_protocol.report.handler.report_handler_interface import IReportHandler
from dongtai_protocol.report.handler.report_handler_interface import IReportHandler, get_agent
from dongtai_protocol.report.report_handler_factory import ReportHandler
import gzip
import base64
Expand Down Expand Up @@ -183,15 +183,14 @@ def save(self):
logger.warning(e, exc_info=True)
else:
current_version_agents = self.get_project_agents(self.agent)
with transaction.atomic():
try:
update_record, method_pool = self.save_method_call(
pool_sign, current_version_agents)
except Exception as e:
logger.info(
f"record method failed : {self.agent_id} {self.http_uri} {self.http_method}"
)
logger.warning(e, exc_info=e)
try:
update_record, method_pool = self.save_method_call(
pool_sign, current_version_agents)
except Exception as e:
logger.info(
f"record method failed : {self.agent_id} {self.http_uri} {self.http_method}"
)
logger.warning(e, exc_info=e)
try:
logger.info(f"send normal method pool {self.agent_id} {self.http_uri} {pool_sign} to celery ")
self.send_to_engine(method_pool_sign=pool_sign,
Expand Down Expand Up @@ -267,46 +266,47 @@ def save_method_call(self, pool_sign: str,
"""
# todo need to del
# pool_sign = random.sample('zyxwvutsrqmlkjihgfedcba',5)
method_pool = MethodPool.objects.filter(
pool_sign=pool_sign, agent__in=current_version_agents).first()
update_record = True
if method_pool:
method_pool.update_time = int(time.time())
method_pool.method_pool = json.dumps(self.method_pool)
method_pool.uri = self.http_uri
method_pool.url = self.http_url
method_pool.http_method = self.http_method
method_pool.req_header = self.http_req_header
method_pool.req_params = self.http_query_string
method_pool.req_data = self.http_req_data
method_pool.req_header_fs = utils.build_request_header(
req_method=self.http_method,
raw_req_header=self.http_req_header,
uri=self.http_uri,
query_params=self.http_query_string,
http_protocol=self.http_protocol)
method_pool.res_header = utils.base64_decode(self.http_res_header)
method_pool.res_body = new_decode_content(
self.http_res_body, get_content_encoding(self.http_res_header),
self.version)
method_pool.uri_sha1 = self.sha1(self.http_uri)
method_pool.save(update_fields=[
'update_time',
'method_pool',
'uri',
'url',
'http_method',
'req_header',
'req_params',
'req_data',
'req_header_fs',
'res_header',
'res_body',
'uri_sha1',
])
else:
# 获取agent
update_record = False
# method_pool = MethodPool.objects.filter(
# pool_sign=pool_sign, agent__in=current_version_agents).first()
# update_record = True
# if method_pool:
# method_pool.update_time = int(time.time())
# method_pool.method_pool = json.dumps(self.method_pool)
# method_pool.uri = self.http_uri
# method_pool.url = self.http_url
# method_pool.http_method = self.http_method
# method_pool.req_header = self.http_req_header
# method_pool.req_params = self.http_query_string
# method_pool.req_data = self.http_req_data
# method_pool.req_header_fs = utils.build_request_header(
# req_method=self.http_method,
# raw_req_header=self.http_req_header,
# uri=self.http_uri,
# query_params=self.http_query_string,
# http_protocol=self.http_protocol)
# method_pool.res_header = utils.base64_decode(self.http_res_header)
# method_pool.res_body = new_decode_content(
# self.http_res_body, get_content_encoding(self.http_res_header),
# self.version)
# method_pool.uri_sha1 = self.sha1(self.http_uri)
# method_pool.save(update_fields=[
# 'update_time',
# 'method_pool',
# 'uri',
# 'url',
# 'http_method',
# 'req_header',
# 'req_params',
# 'req_data',
# 'req_header_fs',
# 'res_header',
# 'res_body',
# 'uri_sha1',
# ])
# else:
# 获取agent
update_record = False
try:
timestamp = int(time.time())
method_pool = MethodPool.objects.create(
agent=self.agent,
Expand Down Expand Up @@ -336,6 +336,9 @@ def save_method_call(self, pool_sign: str,
update_time=timestamp,
uri_sha1=self.sha1(self.http_uri),
)
except (IntegrityError, MultipleObjectsReturned) as e:
logger.info(e)
logger.debug(e, exc_info=e)
return update_record, method_pool

def send_to_engine(self, method_pool_id="", method_pool_sign="", update_record=False, model=None):
Expand Down Expand Up @@ -396,10 +399,26 @@ def sha256(raw):
return h.hexdigest()

def get_agent(self, agent_id):
return IastAgent.objects.filter(id=agent_id,
online=1,
allow_report=1,
user=self.user_id).first()
return get_agent(
agent_id,
{
"pk": agent_id,
"online": 1,
"user": self.user_id,
"allow_report": 1,
},
(
'id',
'bind_project_id',
'project_version_id',
'project_name',
'language',
'project_version_id',
'server_id',
'filepathsimhash',
'servicetype',
),
)


def save_project_header(keys: list, agent_id: int):
Expand Down