Skip to content

Commit

Permalink
Merge pull request #1650 from st1020/feat/change-auth-to-use-projects
Browse files Browse the repository at this point in the history
feat: change auth to use projects
  • Loading branch information
Bidaya0 authored Jul 25, 2023
2 parents a51aded + 46cf9d3 commit a507762
Show file tree
Hide file tree
Showing 31 changed files with 149 additions and 177 deletions.
28 changes: 14 additions & 14 deletions dongtai_common/endpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import logging
from functools import reduce
from operator import ior
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

from django.core.paginator import EmptyPage, Paginator
from django.db.models import Count, Q, QuerySet
from django.db.models import Count, Q
from django.http import JsonResponse
from django.http.request import HttpRequest
from django.utils.translation import gettext_lazy as _
Expand All @@ -23,8 +23,8 @@
from dongtai_common.models.asset import Asset
from dongtai_common.models.asset_aggr import AssetAggr
from dongtai_common.models.asset_vul import IastVulAssetRelation
from dongtai_common.models.department import Department
from dongtai_common.models.log import IastLog, OperateType
from dongtai_common.models.project import IastProject
from dongtai_common.permissions import (
UserPermission,
)
Expand All @@ -33,6 +33,7 @@

if TYPE_CHECKING:
from django.core.paginator import _SupportsPagination
from django.db.models.query import QuerySet, ValuesQuerySet

logger = logging.getLogger("dongtai-core")

Expand Down Expand Up @@ -188,8 +189,8 @@ def parse_args(self, request):

@staticmethod
def get_paginator(
queryset: QuerySet, page: int = 1, page_size: int = 20
) -> tuple[dict, Union[QuerySet, "_SupportsPagination"]]:
queryset: "QuerySet | ValuesQuerySet", page: int = 1, page_size: int = 20
) -> tuple[dict, "QuerySet | _SupportsPagination"]:
"""
根据模型集合、页号、每页大小获取分页数据
:param queryset:
Expand Down Expand Up @@ -257,11 +258,10 @@ def get_auth_agents(users):
:param users:
:return:
"""
qs = Department.objects.none()
qss = [user.get_relative_department() for user in users]
departments = reduce(ior, qss, qs)
return IastAgent.objects.filter(bind_project__department__in=departments)
# if isinstance(users, QuerySet):
qs = IastProject.objects.none()
qss = [user.get_projects() for user in users]
projects = reduce(ior, qss, qs)
return IastAgent.objects.filter(bind_project__in=projects)

@staticmethod
def get_auth_assets(users):
Expand All @@ -270,10 +270,10 @@ def get_auth_assets(users):
:param users:
:return:
"""
qs = Department.objects.none()
qss = [user.get_relative_department() for user in users]
departments = reduce(ior, qss, qs)
return Asset.objects.filter(department__in=departments, is_del=0)
qs = IastProject.objects.none()
qss = [user.get_projects() for user in users]
projects = reduce(ior, qss, qs)
return Asset.objects.filter(project__in=projects, is_del=0)

@staticmethod
def get_auth_asset_aggrs(auth_assets):
Expand Down
9 changes: 9 additions & 0 deletions dongtai_common/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from django.utils.translation import gettext_lazy as _

from dongtai_common.models.department import Department
from dongtai_conf.patch import patch_point


class PermissionsMixin(models.Model):
Expand Down Expand Up @@ -109,3 +110,11 @@ def get_using_department(self):
if self.using_department:
return self.using_department
return self.get_department()

def get_projects(self) -> QuerySet:
from dongtai_common.models.project import IastProject

queryset = IastProject.objects.none()
if self.is_system_admin:
return IastProject.objects.all()
return patch_point(queryset)
17 changes: 8 additions & 9 deletions dongtai_web/aggr_vul/aggr_vul_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class GetAggregationVulList(UserEndPoint):
description = _("New application")

@extend_schema_with_envcheck(
deprecated=True,
request=AggregationArgsSerializer,
tags=[_("漏洞")],
summary=_("组件漏洞列表"),
Expand Down Expand Up @@ -155,11 +156,9 @@ def post(self, request):

except ValidationError as e:
return R.failure(data=e.detail)
departments = list(request.user.get_relative_department())
department_filter_sql = " and {}.department_id in ({})".format(
"asset", ",".join(str(x.id) for x in departments)
)
query_condition = query_condition + department_filter_sql
projects = list(request.user.get_projects())
project_filter_sql = " and {}.project_id in ({})".format("asset", ",".join(str(x.id) for x in projects))
query_condition = query_condition + project_filter_sql

if keywords:
query_base = (
Expand Down Expand Up @@ -245,7 +244,7 @@ def post(self, request):
Asset.objects.filter(
iastvulassetrelation__asset_vul_id__in=vul_ids,
iastvulassetrelation__is_del=0,
department__in=departments,
project__in=projects,
project_id__gt=0,
)
.values("project_id", "iastvulassetrelation__asset_vul_id")
Expand Down Expand Up @@ -337,10 +336,10 @@ def get_vul_list_from_elastic_search(
auth_user_info = auth_user_list_str(user_id=user_id)
auth_user_info["user_list"]
user = User.objects.filter(pk=user_id).first()
departments = user.get_relative_department()
department_ids = [department.id for department in departments]
projects = user.get_projects()
project_ids = [project.id for project in projects]
must_query = [
Q("terms", asset_department_id=department_ids),
Q("terms", asset_project_id=project_ids),
Q("terms", asset_vul_relation_is_del=[0]),
Q("range", asset_project_id={"gt": 0}),
]
Expand Down
12 changes: 6 additions & 6 deletions dongtai_web/aggr_vul/aggr_vul_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def get_annotate_sca_base_data(user_id: int, pro_condition: str):
"project": [],
}
user = User.objects.get(pk=user_id)
departments = list(user.get_relative_department())
department_filter_sql = " and {}.department_id in ({})".format("asset", ",".join(str(x.id) for x in departments))
query_condition = " where rel.is_del=0 and asset.project_id>0 " + department_filter_sql + pro_condition
projects = list(user.get_projects())
project_filter_sql = " and {}.project_id in ({})".format("asset", ",".join(str(x.id) for x in projects))
query_condition = " where rel.is_del=0 and asset.project_id>0 " + project_filter_sql + pro_condition
base_join = (
"left JOIN iast_asset_vul_relation as rel on rel.asset_vul_id=vul.id "
"left JOIN iast_asset as asset on rel.asset_id=asset.id "
Expand Down Expand Up @@ -201,10 +201,10 @@ def get_annotate_data_es(user_id, bind_project_id=None, project_version_id=None)
from dongtai_web.utils import dict_transfrom

user = User.objects.get(pk=user_id)
departments = list(user.get_relative_department())
department_ids = [i.id for i in departments]
projects = list(user.get_projects())
project_ids = [i.id for i in projects]
must_query = [
Q("terms", asset_department_id=department_ids),
Q("terms", asset_project_id=project_ids),
Q("terms", asset_vul_relation_is_del=[0]),
Q("range", asset_project_id={"gt": 0}),
]
Expand Down
16 changes: 7 additions & 9 deletions dongtai_web/aggr_vul/app_vul_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,8 @@ def post(self, request):
}
ser = AggregationArgsSerializer(data=request.data)
# 获取用户权限
departments = request.user.get_relative_department()
queryset = IastVulnerabilityModel.objects.filter(
is_del=0, project_id__gt=0, project__department__in=departments
)
projects = request.user.get_projects()
queryset = IastVulnerabilityModel.objects.filter(is_del=0, project_id__gt=0, project__in=projects)

try:
if ser.is_valid(True):
Expand Down Expand Up @@ -148,7 +146,7 @@ def post(self, request):
order_list.append(order_type_desc + order_type)
es_query["order"] = order_type_desc + order_type
if ELASTICSEARCH_STATE:
vul_data = get_vul_list_from_elastic_search(departments, page=page, page_size=page_size, **es_query)
vul_data = get_vul_list_from_elastic_search(projects, page=page, page_size=page_size, **es_query)
else:
vul_data = queryset.values(*tuple(fields)).order_by(*tuple(order_list))[begin_num:end_num]
except ValidationError as e:
Expand Down Expand Up @@ -214,7 +212,7 @@ def set_vul_inetration(end: dict[str, Any], user_id: int) -> None:


def get_vul_list_from_elastic_search(
departments,
projects,
project_ids=None,
project_version_ids=None,
hook_type_ids=None,
Expand Down Expand Up @@ -246,9 +244,9 @@ def get_vul_list_from_elastic_search(

from dongtai_common.models.strategy import IastStrategyModel

department_ids = list(departments.values_list("id", flat=True))
auth_project_ids = list(project_ids.values_list("id", flat=True))
must_query = [
Q("terms", department_id=department_ids),
Q("terms", bind_project_id=auth_project_ids),
Q("terms", is_del=[0]),
Q("range", bind_project_id={"gt": 0}),
Q("range", strategy_id={"gt": 0}),
Expand Down Expand Up @@ -290,7 +288,7 @@ def get_vul_list_from_elastic_search(
a = Q("bool", must=must_query)
hashkey = make_hash(
[
department_ids,
auth_project_ids,
project_ids,
project_version_ids,
hook_type_ids,
Expand Down
29 changes: 12 additions & 17 deletions dongtai_web/aggr_vul/app_vul_summary.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import logging

from django.db.models import Count, Q
from django.db.models.query import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework.serializers import ValidationError

from dongtai_common.endpoint import R, UserEndPoint
from dongtai_common.models.department import Department
from dongtai_common.models.project import IastProject
from dongtai_common.models.vulnerablity import IastVulnerabilityModel
from dongtai_common.utils.const import OPERATE_GET
from dongtai_conf.patch import patch_point
Expand All @@ -20,18 +21,12 @@ def _annotate_by_query(q, value_fields, count_field):
return IastVulnerabilityModel.objects.filter(q).values(*value_fields).annotate(count=Count(count_field))


# @cached_decorator(random_range=(2 * 60 * 60, 2 * 60 * 60),
# use_celery_update=True)
def get_annotate_cache_data(projects: QuerySet[IastProject]):
return get_annotate_data(projects, 0, 0)


def get_annotate_cache_data(department: Department):
return get_annotate_data(department, 0, 0)


def get_annotate_data(department: Department, bind_project_id=int, project_version_id=int) -> dict:
# cache_q = Q(is_del=0, agent__bind_project_id__gt=0,
# agent__user_id__in=auth_user_info['user_list'])
cache_q = Q(is_del=0, project_id__gt=0, project__department__in=department)
def get_annotate_data(projects: QuerySet[IastProject], bind_project_id: int, project_version_id: int) -> dict:
cache_q = Q(is_del=0, project_id__gt=0, project__in=projects)

# 从项目列表进入 绑定项目id
if bind_project_id:
Expand Down Expand Up @@ -98,7 +93,7 @@ def post(self, request):
:return:
"""

department = request.user.get_relative_department()
projects = request.user.get_projects()

ser = AggregationArgsSerializer(data=request.data)
bind_project_id = 0
Expand All @@ -111,12 +106,12 @@ def post(self, request):
project_version_id = ser.validated_data.get("project_version_id", 0)

if ELASTICSEARCH_STATE:
result_summary = get_annotate_data_es(department, bind_project_id, project_version_id)
result_summary = get_annotate_data_es(projects, bind_project_id, project_version_id)
elif bind_project_id or project_version_id:
result_summary = get_annotate_data(department, bind_project_id, project_version_id)
result_summary = get_annotate_data(projects, bind_project_id, project_version_id)
else:
# 全局下走缓存
result_summary = get_annotate_cache_data(department)
result_summary = get_annotate_cache_data(projects)
except ValidationError as e:
logger.info(e)
return R.failure(data=e.detail)
Expand All @@ -128,7 +123,7 @@ def post(self, request):
)


def get_annotate_data_es(department: Department, bind_project_id, project_version_id):
def get_annotate_data_es(projects: QuerySet[IastProject], bind_project_id: int, project_version_id: int):
from elasticsearch import Elasticsearch
from elasticsearch_dsl import A, Q

Expand All @@ -142,7 +137,7 @@ def get_annotate_data_es(department: Department, bind_project_id, project_versio
strategy_ids = list(IastStrategyModel.objects.all().values_list("id", flat=True))

must_query = [
Q("terms", department_id=list(department.values_list("id", flat=True))),
Q("terms", bind_project_id=list(projects.values_list("id", flat=True))),
Q("terms", is_del=[0]),
Q("terms", is_del=[0]),
Q("range", bind_project_id={"gt": 0}),
Expand Down
16 changes: 5 additions & 11 deletions dongtai_web/aggregation/aggregation_del.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,25 @@ def post(self, request):
ids = request.data.get("ids", "")
ids = turnIntListOfStr(ids)
source_type = request.data.get("source_type", 1)
department = request.user.get_relative_department()
projects = request.user.get_projects()
if source_type == 1:
queryset = IastVulnerabilityModel.objects.filter(is_del=0)
else:
queryset = IastVulAssetRelation.objects.filter(is_del=0)

# 部门删除逻辑
# 项目删除逻辑
if source_type == 1:
queryset = queryset.filter(project__department__in=department)
queryset = queryset.filter(project__in=projects)
else:
queryset = queryset.filter(asset__department__in=department)
queryset = queryset.filter(asset__project__in=projects)

if source_type == 1: # noqa: SIM108
# 应用漏洞删除
del_queryset = queryset.filter(id__in=ids)
else:
# 组件漏洞删除
del_queryset = queryset.filter(asset_vul_id__in=ids)
# with connection.cursor() as cursor:
# sca_ids_str)
for vul in del_queryset:
vul.is_del = 1
vul.save()
return R.success(
data={
"messages": "success",
},
)
return R.success(data={"messages": "success"})
12 changes: 4 additions & 8 deletions dongtai_web/aggregation/aggregation_project_del.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def post(self, request):
return R.failure()
project_version_id = request.data.get("project_version_id", None)
source_type = request.data.get("source_type", 1)
department = request.user.get_relative_department()
project = request.user.get_projects()
if source_type == 1:
queryset = IastVulnerabilityModel.objects.filter(is_del=0)
else:
Expand All @@ -44,15 +44,11 @@ def post(self, request):

# 部门删除逻辑
if source_type == 1:
queryset = queryset.filter(project__department__in=department)
queryset = queryset.filter(project__in=project)
else:
queryset = queryset.filter(asset__department__in=department)
queryset = queryset.filter(asset__project__in=project)

for vul in queryset:
vul.is_del = 1
vul.save()
return R.success(
data={
"messages": "success",
},
)
return R.success(data={"messages": "success"})
5 changes: 3 additions & 2 deletions dongtai_web/base/project_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.db import transaction
from django.db.models import Q
from django.db.models.query import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers

Expand All @@ -20,13 +21,13 @@ class VersionModifySerializer(serializers.Serializer):


@transaction.atomic
def version_modify(user, department, versionData=None):
def version_modify(user, projects: QuerySet[IastProject], versionData):
version_id = versionData.get("version_id", 0)
project_id = versionData.get("project_id", 0)
current_version = versionData.get("current_version", 0)
version_name = versionData.get("version_name", "")
description = versionData.get("description", "")
project = IastProject.objects.filter(department__in=department, id=project_id).only("id", "user").first()
project = projects.filter(id=project_id).only("id", "user").first()
if not version_name or not project:
return {"status": "202", "msg": _("Parameter error")}
baseVersion = IastProjectVersion.objects.filter(
Expand Down
3 changes: 1 addition & 2 deletions dongtai_web/dongtai_sca/views/newpackageprojects.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def get(self, request, language_id, package_name, package_version):
pass
except ValidationError as e:
return R.failure(data=e.detail)
departments = request.user.get_relative_department()
queryset = IastProject.objects.filter(department__in=departments).order_by("-latest_time")
queryset = request.user.get_projects().order_by("-latest_time")
assets_project_ids = (
AssetV2.objects.filter(
language_id=language_id,
Expand Down
Loading

0 comments on commit a507762

Please sign in to comment.