diff --git a/.env b/.env deleted file mode 100644 index 4f2da5e..0000000 --- a/.env +++ /dev/null @@ -1,52 +0,0 @@ -# Django Configuration -SECRET_KEY=django-insecure-0a1bx*8!97l^4z#ml#ufn_*9ut*)zlso$*k-g^h&(2=p@^51md -DEBUG=True -#ALLOWED_HOSTS=localhost,127.0.0.1,django-host,192.168.1.22 -ALLOWED_HOSTS=* -# 切换数据源,支持sqlite/mysql -DB_ENGINE=sqlite -USE_REDIS_AS_DB=True - -# MySQL Configuration (when USE_REDIS_AS_DB=False) -DB_NAME=hertz_server -DB_USER=root -DB_PASSWORD=123456 -DB_HOST=localhost -DB_PORT=3306 - -# Redis Configuration -REDIS_URL=redis://127.0.0.1:6379/0 - -# CORS Configuration -CORS_ALLOWED_ORIGINS=http://localhost:3000,http://127.0.0.1:3000 -CORS_ALLOW_ALL_ORIGINS=True - - -# Email Configuration -EMAIL_BACKEND=django.core.mail.backends.smtp.EmailBackend -EMAIL_HOST=smtp.qq.com -EMAIL_PORT=465 -EMAIL_USE_SSL=True -EMAIL_USE_TLS=False -EMAIL_HOST_USER=your_email@qq.com -EMAIL_HOST_PASSWORD=your_email_password -DEFAULT_FROM_EMAIL=your_email@qq.com - -# 注册邮箱验证码开关(0=关闭,1=开启) -REGISTER_EMAIL_VERIFICATION=0 - -# Hertz Captcha Configuration -HERTZ_CAPTCHA_LENGTH=4 -HERTZ_CAPTCHA_WIDTH=120 -HERTZ_CAPTCHA_HEIGHT=50 -HERTZ_CAPTCHA_FONT_SIZE=30 -HERTZ_CAPTCHA_TIMEOUT=300 -HERTZ_CAPTCHA_BACKGROUND_COLOR=#ffffff -HERTZ_CAPTCHA_FOREGROUND_COLOR=#000000 -HERTZ_CAPTCHA_NOISE_LEVEL=0.3 -HERTZ_CAPTCHA_REDIS_KEY_PREFIX=hertz_captcha: - -# Auth Middleware Configuration - 不需要登录验证的URL模式(支持正则表达式) -# 格式:使用逗号分隔的正则表达式模式 -# 示例:/api/demo 表示demo接口,/api/.* 表示/api路径下的所有 -NO_AUTH_PATTERNS=^/api/auth/login/?$,^/api/auth/register/?$,^/api/auth/email/code/?$,^/api/auth/send-email-code/?$,^/api/auth/password/reset/?$,^/api/captcha/.*$,^/api/docs/.*$,^/api/redoc/.*$,^/api/schema/.*$,^/admin/.*$,^/static/.*$,^/media/.*$,^/demo/.*$,^/websocket/.*$,^/api/system/.*$,^/yolo/.*$, diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 85f7b13..0000000 --- a/.gitignore +++ /dev/null @@ -1,101 +0,0 @@ -# Python bytecode -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - 
-# python envs -venv/ - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Django -*.log -local_settings.py -media/ -staticfiles/ -static_root/ - - -# IDE / editors -.idea/ -.vscode/ -*.iml - -# Type checking -.mypy_cache/ -.pytype/ -.pyre/ - -# Celery -celerybeat-schedule -celerybeat.pid - -# Sphinx docs -docs/_build/ - -# PyInstaller -*.manifest -*.spec - -# Project-specific data/models (uploads & intermediates) -media/models/ -media/uploads/ -media/yolo/temp/ -media/yolo/models/ -media/sklearn/models/ -media/detection/temp/ -media/detection/result/ -media/detection/original/ - -# Logs -logs/ -*.log - -# OS files -.DS_Store -Thumbs.db - -# Yolo models -hertz_studio_django_utils/yolo/Train/runs/ - -# 技术支持文档 -shared/ diff --git a/data/db.sqlite3 b/data/db.sqlite3 deleted file mode 100644 index 49048b1..0000000 Binary files a/data/db.sqlite3 and /dev/null differ diff --git a/data/frontend_7z/admin_AlertLevelManagement.7z b/data/frontend_7z/admin_AlertLevelManagement.7z deleted file mode 100644 index a5e12e0..0000000 Binary files a/data/frontend_7z/admin_AlertLevelManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_AlertProcessingCenter.7z b/data/frontend_7z/admin_AlertProcessingCenter.7z deleted file mode 100644 index 77dd675..0000000 Binary files a/data/frontend_7z/admin_AlertProcessingCenter.7z and /dev/null differ diff --git a/data/frontend_7z/admin_ArticleManagement.7z b/data/frontend_7z/admin_ArticleManagement.7z deleted file mode 100644 index 59e6c68..0000000 Binary files a/data/frontend_7z/admin_ArticleManagement.7z and /dev/null differ diff --git 
a/data/frontend_7z/admin_Dashboard.7z b/data/frontend_7z/admin_Dashboard.7z deleted file mode 100644 index 28bf330..0000000 Binary files a/data/frontend_7z/admin_Dashboard.7z and /dev/null differ diff --git a/data/frontend_7z/admin_DatasetManagement.7z b/data/frontend_7z/admin_DatasetManagement.7z deleted file mode 100644 index b3e82bc..0000000 Binary files a/data/frontend_7z/admin_DatasetManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_DepartmentManagement.7z b/data/frontend_7z/admin_DepartmentManagement.7z deleted file mode 100644 index 2c963d5..0000000 Binary files a/data/frontend_7z/admin_DepartmentManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_DetectionHistoryManagement.7z b/data/frontend_7z/admin_DetectionHistoryManagement.7z deleted file mode 100644 index 94c28bd..0000000 Binary files a/data/frontend_7z/admin_DetectionHistoryManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_LogManagement.7z b/data/frontend_7z/admin_LogManagement.7z deleted file mode 100644 index c71d6a0..0000000 Binary files a/data/frontend_7z/admin_LogManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_MenuManagement.7z b/data/frontend_7z/admin_MenuManagement.7z deleted file mode 100644 index 0f8bf2a..0000000 Binary files a/data/frontend_7z/admin_MenuManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_ModelManagement.7z b/data/frontend_7z/admin_ModelManagement.7z deleted file mode 100644 index ebb2b0f..0000000 Binary files a/data/frontend_7z/admin_ModelManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_NotificationManagement.7z b/data/frontend_7z/admin_NotificationManagement.7z deleted file mode 100644 index 297e512..0000000 Binary files a/data/frontend_7z/admin_NotificationManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_Role.7z b/data/frontend_7z/admin_Role.7z deleted file mode 100644 index 7c1b69b..0000000 Binary files 
a/data/frontend_7z/admin_Role.7z and /dev/null differ diff --git a/data/frontend_7z/admin_UserManagement.7z b/data/frontend_7z/admin_UserManagement.7z deleted file mode 100644 index eaf111a..0000000 Binary files a/data/frontend_7z/admin_UserManagement.7z and /dev/null differ diff --git a/data/frontend_7z/admin_YoloTrainManagement.7z b/data/frontend_7z/admin_YoloTrainManagement.7z deleted file mode 100644 index a1b9c7b..0000000 Binary files a/data/frontend_7z/admin_YoloTrainManagement.7z and /dev/null differ diff --git a/data/frontend_7z/user_AIChat.7z b/data/frontend_7z/user_AIChat.7z deleted file mode 100644 index 77a50a9..0000000 Binary files a/data/frontend_7z/user_AIChat.7z and /dev/null differ diff --git a/data/frontend_7z/user_ArticleCenter.7z b/data/frontend_7z/user_ArticleCenter.7z deleted file mode 100644 index 8d269e7..0000000 Binary files a/data/frontend_7z/user_ArticleCenter.7z and /dev/null differ diff --git a/data/frontend_7z/user_KnowledgeBase.7z b/data/frontend_7z/user_KnowledgeBase.7z deleted file mode 100644 index 61ec1d9..0000000 Binary files a/data/frontend_7z/user_KnowledgeBase.7z and /dev/null differ diff --git a/data/frontend_7z/user_LiveDetection.7z b/data/frontend_7z/user_LiveDetection.7z deleted file mode 100644 index bb7334b..0000000 Binary files a/data/frontend_7z/user_LiveDetection.7z and /dev/null differ diff --git a/data/frontend_7z/user_Messages.7z b/data/frontend_7z/user_Messages.7z deleted file mode 100644 index eebc4f2..0000000 Binary files a/data/frontend_7z/user_Messages.7z and /dev/null differ diff --git a/data/frontend_7z/user_Profile.7z b/data/frontend_7z/user_Profile.7z deleted file mode 100644 index e3c0769..0000000 Binary files a/data/frontend_7z/user_Profile.7z and /dev/null differ diff --git a/data/frontend_7z/user_YoloDetection.7z b/data/frontend_7z/user_YoloDetection.7z deleted file mode 100644 index 76e3545..0000000 Binary files a/data/frontend_7z/user_YoloDetection.7z and /dev/null differ diff --git 
a/docs/API接口文档/AI聊天模块接口文档.md b/docs/API接口文档/AI聊天模块接口文档.md deleted file mode 100644 index 827b851..0000000 --- a/docs/API接口文档/AI聊天模块接口文档.md +++ /dev/null @@ -1,246 +0,0 @@ -# Hertz Studio Django AI聊天模块接口文档 - -- 基础路径: `/api/ai/` -- 统一响应: 使用 `HertzResponse`,结构如下 - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由通过 `path('api/ai/', include('hertz_studio_django_ai.urls'))` 挂载(`hertz_server_django/urls.py:23`)。 -- 认证说明: 所有接口需在请求头携带 `Authorization: Bearer `(`hertz_studio_django_ai/views.py:34` 使用 `login_required`)。 - -## 接口列表 - -### 获取聊天列表 -- 方法: `GET` -- 路径: `/api/ai/chats/` -- 查询参数: - - `query` 可选,按标题模糊搜索 - - `page` 可选,默认 `1` - - `page_size` 可选,默认 `10` -- 实现: `AIChatListView.get`(`hertz_studio_django_ai/views.py:36`) -- 请求示例: - ```http - GET /api/ai/chats/?query=Python&page=1&page_size=10 - Authorization: Bearer - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "total": 25, - "page": 1, - "page_size": 10, - "chats": [ - { - "id": 1, - "title": "Python编程问题", - "created_at": "2024-01-15 10:30:00", - "updated_at": "2024-01-15 10:35:00", - "latest_message": "如何使用Django创建API接口?..." 
- } - ] - } - } - ``` - -### 创建聊天 -- 方法: `POST` -- 路径: `/api/ai/chats/create/` -- 请求体: `application/json` -- 字段: `title` 可选,默认 `"新对话"` -- 实现: `AIChatCreateView.post`(`hertz_studio_django_ai/views.py:100`) -- 请求示例: - ```http - POST /api/ai/chats/create/ - Authorization: Bearer - Content-Type: application/json - - { - "title": "AI编程助手" - } - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "创建成功", - "data": { - "chat_id": 3, - "title": "AI编程助手" - } - } - ``` - -### 聊天详情 -- 方法: `GET` -- 路径: `/api/ai/chats/{chat_id}/` -- 路径参数: `chat_id` 聊天ID(整数) -- 实现: `AIChatDetailView.get`(`hertz_studio_django_ai/views.py:137`) -- 请求示例: - ```http - GET /api/ai/chats/1/ - Authorization: Bearer - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "id": 1, - "title": "Python编程问题", - "created_at": "2024-01-15 10:30:00", - "updated_at": "2024-01-15 10:35:00", - "messages": [ - { - "id": 1, - "role": "user", - "content": "如何使用Django创建API接口?", - "created_at": "2024-01-15 10:30:00" - }, - { - "id": 2, - "role": "assistant", - "content": "使用Django REST Framework可以快速构建API...", - "created_at": "2024-01-15 10:30:30" - } - ] - } - } - ``` - -### 更新聊天 -- 方法: `PUT` -- 路径: `/api/ai/chats/{chat_id}/update/` -- 路径参数: `chat_id` 聊天ID(整数) -- 请求体: `application/json` -- 字段: `title` 新标题 -- 实现: `AIChatUpdateView.put`(`hertz_studio_django_ai/views.py:192`) -- 请求示例: - ```http - PUT /api/ai/chats/1/update/ - Authorization: Bearer - Content-Type: application/json - - { - "title": "更新后的标题" - } - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "更新成功", - "data": null - } - ``` - -### 删除聊天(批量) -- 方法: `POST` -- 路径: `/api/ai/chats/delete/` -- 请求体: `application/json` -- 字段: `chat_ids` 要删除的聊天ID数组(整数) -- 实现: `AIChatDeleteView.post`(`hertz_studio_django_ai/views.py:231`) -- 请求示例: - ```http - POST /api/ai/chats/delete/ - Authorization: Bearer - Content-Type: application/json - - { - "chat_ids": [1, 2, 3] - } - ``` -- 返回示例: - 
```json - { - "success": true, - "code": 200, - "message": "成功删除3个聊天", - "data": null - } - ``` - -### 发送消息 -- 方法: `POST` -- 路径: `/api/ai/chats/{chat_id}/send/` -- 路径参数: `chat_id` 聊天ID(整数) -- 请求体: `application/json` -- 字段: `content` 必填,消息内容,不能为空 -- 实现: `AIChatSendMessageView.post`(`hertz_studio_django_ai/views.py:280`) -- 请求示例: - ```http - POST /api/ai/chats/1/send/ - Authorization: Bearer - Content-Type: application/json - - { - "content": "你好,请介绍一下Python的特点" - } - ``` -- 成功返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "user_message": { - "id": 5, - "role": "user", - "content": "你好,请介绍一下Python的特点", - "created_at": "2024-01-15 11:30:00" - }, - "ai_message": { - "id": 6, - "role": "assistant", - "content": "Python是一种高级编程语言,语法简洁,生态丰富...", - "created_at": "2024-01-15 11:30:05" - } - } - } - ``` -- 失败返回示例(参数验证失败): - ```json - { - "success": false, - "code": 422, - "message": "参数验证失败", - "data": { - "content": ["消息内容不能为空"] - } - } - ``` -- 失败返回示例(聊天不存在): - ```json - { - "success": false, - "code": 404, - "message": "聊天不存在或无权访问", - "data": null - } - ``` -- 失败返回示例(AI生成失败): - ```json - { - "success": false, - "code": 500, - "message": "AI回复生成失败:服务不可用", - "data": null - } - ``` - -## 附注 -- 列表与详情时间字段为字符串格式 `YYYY-MM-DD HH:mm:ss`(`hertz_studio_django_ai/views.py:65`、`hertz_studio_django_ai/views.py:151`)。 -- AI模型名称由配置项 `settings.AI_MODEL_NAME` 控制,默认 `deepseek-r1:1.5b`(`hertz_studio_django_ai/views.py:263`)。 \ No newline at end of file diff --git a/docs/API接口文档/Wiki模块接口文档.md b/docs/API接口文档/Wiki模块接口文档.md deleted file mode 100644 index a35e477..0000000 --- a/docs/API接口文档/Wiki模块接口文档.md +++ /dev/null @@ -1,367 +0,0 @@ -# Hertz Studio Django Wiki 接口文档 - -- 基础路径: `/api/wiki/` -- 统一响应: 使用 `HertzResponse`,结构如下(参考 `hertz_studio_django_utils/responses/HertzResponse.py`) - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由中已通过 `path('api/wiki/', include('hertz_studio_django_wiki.urls'))` 
挂载(`hertz_server_django/urls.py:29`)。 -- 认证说明: 标注“需要登录”的接口需在请求头携带 `Authorization: Bearer `(`hertz_studio_django_auth/utils/decorators/auth_decorators.py:1`)。 - - -## 一、知识分类管理 - -### (1)获取分类列表 -- 方法: `GET` -- 路径: `/api/wiki/categories/` -- 认证: 不需要 -- 查询参数: - - `page`: 页码,默认 `1` - - `page_size`: 每页数量,默认 `10` - - `name`: 分类名称关键字 - - `parent_id`: 父分类ID(`0` 表示顶级) - - `is_active`: `true/false` -- 实现: `wiki_category_list`(`hertz_studio_django_wiki/views.py:41`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/wiki/categories/?page=1&page_size=10&name=技术" - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "list": [ - { - "id": 1, - "name": "技术文档", - "description": "技术相关文档", - "parent": null, - "parent_name": null, - "sort_order": 1, - "is_active": true, - "created_at": "2024-01-01T10:00:00Z", - "updated_at": "2024-01-01T10:00:00Z", - "children_count": 3, - "articles_count": 15, - "full_path": "技术文档" - } - ], - "total": 1, - "page": 1, - "page_size": 10 - } - } - ``` - -### (2)获取树形分类 -- 方法: `GET` -- 路径: `/api/wiki/categories/tree/` -- 认证: 不需要 -- 实现: `wiki_category_tree`(`hertz_studio_django_wiki/views.py:101`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/wiki/categories/tree/" - ``` -- 返回示例: - ```json - [ - { - "id": 1, - "name": "技术文档", - "description": "技术相关文档", - "sort_order": 1, - "is_active": true, - "articles_count": 15, - "children": [ - {"id": 2, "name": "后端", "children": []} - ] - } - ] - ``` - -### (3)创建分类 -- 方法: `POST` -- 路径: `/api/wiki/categories/create/` -- 认证: 不需要 -- 请求体: `application/json` -- 字段: `name`, `description`, `parent`, `sort_order`, `is_active` -- 实现: `wiki_category_create`(`hertz_studio_django_wiki/views.py:136`) -- 请求示例: - ```http - POST /api/wiki/categories/create/ - Content-Type: application/json - - { - "name": "新分类", - "description": "分类描述", - "parent": null, - "sort_order": 10, - "is_active": true - } - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": 
"知识分类创建成功", - "data": { - "id": 5, - "name": "新分类", - "description": "分类描述", - "parent": null, - "parent_name": null, - "sort_order": 10, - "is_active": true, - "created_at": "2025-11-17T10:00:00Z", - "updated_at": "2025-11-17T10:00:00Z", - "children_count": 0, - "articles_count": 0, - "full_path": "新分类" - } - } - ``` - -### (4)分类详情 -- 方法: `GET` -- 路径: `/api/wiki/categories/{category_id}/` -- 认证: 不需要 -- 实现: `wiki_category_detail`(`hertz_studio_django_wiki/views.py:178`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/wiki/categories/1/" - ``` -- 返回示例: 同“获取分类列表”中的单项结构。 - -### (5)更新分类 -- 方法: `PUT`(支持部分更新) -- 路径: `/api/wiki/categories/{category_id}/update/` -- 认证: 不需要 -- 请求体: `application/json` -- 可更新字段: `name`, `description`, `parent`, `sort_order`, `is_active` -- 实现: `wiki_category_update`(`hertz_studio_django_wiki/views.py:220`) -- 请求示例: - ```http - PUT /api/wiki/categories/1/update/ - Content-Type: application/json - - { - "name": "更新后的分类名", - "description": "更新后的描述", - "sort_order": 20 - } - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "知识分类更新成功", "data": {"id": 1, "name": "更新后的分类名"}} - ``` - -### (6)删除分类 -- 方法: `DELETE` -- 路径: `/api/wiki/categories/{category_id}/delete/` -- 认证: 不需要 -- 行为: 软删除(将 `is_active=false`);若存在子分类或文章将返回错误 -- 实现: `wiki_category_delete`(`hertz_studio_django_wiki/views.py:270`) -- 请求示例: - ```bash - curl -X DELETE "http://localhost:8000/api/wiki/categories/1/delete/" - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "知识分类删除成功"} - ``` - - -## 二、知识文章管理 - -### (1)获取文章列表 -- 方法: `GET` -- 路径: `/api/wiki/articles/` -- 认证: 不需要 -- 查询参数: - - `page`, `page_size` - - `title`: 标题关键字 - - `category_id`: 分类ID - - `author_id`: 作者ID - - `status`: `draft|published|archived` -- 实现: `wiki_article_list`(`hertz_studio_django_wiki/views.py:318`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/wiki/articles/?page=1&page_size=10&category_id=1&status=published" - ``` -- 返回示例: - ```json - { - "success": true, - "code": 
200, - "message": "操作成功", - "data": { - "list": [ - { - "id": 101, - "title": "如何部署Django", - "summary": "部署流程概览", - "image": null, - "category_name": "技术文档", - "author_name": "alice", - "status": "published", - "status_display": "已发布", - "view_count": 42, - "created_at": "2025-11-01T09:00:00Z", - "updated_at": "2025-11-10T09:00:00Z", - "published_at": "2025-11-10T09:00:00Z" - } - ], - "total": 1, - "page": 1, - "page_size": 10 - } - } - ``` - -### (2)创建文章(需要登录) -- 方法: `POST` -- 路径: `/api/wiki/articles/create/` -- 认证: 需要登录(`Authorization: Bearer `) -- 请求体: `application/json` -- 字段: `title`, `content`, `summary`, `image`, `category`, `status`, `tags`, `sort_order` -- 实现: `wiki_article_create`(`hertz_studio_django_wiki/views.py:384`) -- 请求示例: - ```http - POST /api/wiki/articles/create/ - Authorization: Bearer - Content-Type: application/json - - { - "title": "新文章", - "content": "文章内容...", - "summary": "文章摘要...", - "image": null, - "category": 1, - "status": "draft", - "tags": "django,部署", - "sort_order": 10 - } - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "知识文章创建成功", - "data": { - "id": 102, - "title": "新文章", - "content": "文章内容...", - "summary": "文章摘要...", - "image": null, - "category": 1, - "category_name": "技术文档", - "author": 2, - "author_name": "alice", - "status": "draft", - "status_display": "草稿", - "tags": "django,部署", - "tags_list": ["django", "部署"], - "view_count": 0, - "sort_order": 10, - "created_at": "2025-11-17T11:00:00Z", - "updated_at": "2025-11-17T11:00:00Z", - "published_at": null - } - } - ``` - -### (3)文章详情 -- 方法: `GET` -- 路径: `/api/wiki/articles/{article_id}/` -- 认证: 不需要 -- 行为: 获取详情并增加浏览量 -- 实现: `wiki_article_detail`(`hertz_studio_django_wiki/views.py:426`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/wiki/articles/102/" - ``` -- 返回示例: 同“创建文章”中的完整字段结构,`view_count` 将递增。 - -### (4)更新文章 -- 方法: `PUT`(支持部分更新) -- 路径: `/api/wiki/articles/{article_id}/update/` -- 认证: 不需要 -- 请求体: `application/json` -- 可更新字段: `title`, 
`content`, `summary`, `image`, `category`, `status`, `tags`, `sort_order` -- 实现: `wiki_article_update`(`hertz_studio_django_wiki/views.py:472`) -- 请求示例: - ```http - PUT /api/wiki/articles/102/update/ - Content-Type: application/json - - { - "status": "published", - "summary": "更新后的摘要" - } - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "知识文章更新成功", "data": {"id": 102, "status": "published"}} - ``` - -### (5)删除文章 -- 方法: `DELETE` -- 路径: `/api/wiki/articles/{article_id}/delete/` -- 认证: 不需要 -- 实现: `wiki_article_delete`(`hertz_studio_django_wiki/views.py:518`) -- 请求示例: - ```bash - curl -X DELETE "http://localhost:8000/api/wiki/articles/102/delete/" - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "知识文章删除成功"} - ``` - -### (6)发布文章 -- 方法: `POST` -- 路径: `/api/wiki/articles/{article_id}/publish/` -- 认证: 不需要 -- 实现: `wiki_article_publish`(`hertz_studio_django_wiki/views.py:561`) -- 请求示例: - ```bash - curl -X POST "http://localhost:8000/api/wiki/articles/102/publish/" - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "知识文章发布成功"} - ``` -- 失败示例(已发布再次发布): - ```json - {"success": false, "code": 500, "message": "系统错误", "error": "文章已经是发布状态"} - ``` - -### (7)归档文章 -- 方法: `POST` -- 路径: `/api/wiki/articles/{article_id}/archive/` -- 认证: 不需要 -- 实现: `wiki_article_archive`(`hertz_studio_django_wiki/views.py:609`) -- 请求示例: - ```bash - curl -X POST "http://localhost:8000/api/wiki/articles/102/archive/" - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "知识文章归档成功"} - ``` - - -## 三、备注 -- 文章状态枚举: `draft|published|archived`(`hertz_studio_django_wiki/models.py:35`)。 -- 分类软删除通过 `is_active=false` 实现;删除校验会阻止删除存在子分类或文章的分类(`views.py:270`)。 -- 文章详情接口会递增 `view_count`(`hertz_studio_django_wiki/models.py:70` 和 `views.py:431`)。 \ No newline at end of file diff --git a/docs/API接口文档/YOLO模块接口文档.md b/docs/API接口文档/YOLO模块接口文档.md deleted file mode 100644 index 0a9d1ec..0000000 --- a/docs/API接口文档/YOLO模块接口文档.md +++ /dev/null @@ -1,461 +0,0 @@ 
-# Hertz Studio Django YOLO 接口文档 - -- 基础路径: `/api/yolo/` -- 统一响应: 使用 `HertzResponse` 封装,结构为: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` - 参考 `hertz_studio_django_utils/responses/HertzResponse.py` -- 认证说明: 标注“需要登录”的接口需在请求头携带 `Authorization: Bearer `,验证逻辑参考 `hertz_studio_django_auth/utils/decorators/auth_decorators.py`。 - - -## 一、模型上传与转换 - -### (1)上传模型(压缩包或文件夹) -- 方法: `POST` -- 路径: `/api/yolo/upload/` -- 认证: 不需要 -- 请求类型: `multipart/form-data` -- 参数: - - `zip_file`: ZIP 压缩包文件,与 `folder_files` 互斥 - - `folder_files[...]`: 文件夹内的多个文件(键名形如 `folder_files[path/to/file]`),与 `zip_file` 互斥 - - `name`: 模型名称(必填) - - `version`: 模型版本(默认 `1.0`) - - `description`: 模型描述(可选) -- 参考实现: `views.py` 中 `model_upload`,`_handle_zip_upload`,`_handle_folder_upload`,`_validate_and_create_model`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1018,1058,1129,1182) -- 示例请求(ZIP 上传): - ```bash - curl -X POST "http://localhost:8000/api/yolo/upload/" \ - -F "zip_file=@/path/to/yolo_model.zip" \ - -F "name=MyYolo" \ - -F "version=1.0" \ - -F "description=Demo model" - ``` -- 示例响应: - ```json - { - "success": true, - "code": 200, - "message": "模型上传成功", - "data": { - "id": 12, - "name": "MyYolo", - "version": "1.0", - "folder_path": "/absolute/path/to/media/models/MyYolo_xxxxxxxx", - "weights_path": "/absolute/path/to/media/models/MyYolo_xxxxxxxx/weights", - "model_path": "/absolute/path/to/media/models/MyYolo_xxxxxxxx/weights/best.pt", - "best_model_path": "/absolute/path/to/media/models/.../weights/best.pt", - "last_model_path": "/absolute/path/to/media/models/.../weights/last.pt", - "categories": {"0":"person","1":"bicycle"}, - "created_at": "2025-10-22T03:20:00Z" - } - } - ``` - -### (2)上传 .pt 并转换为 ONNX -- 方法: `POST` -- 路径: `/api/yolo/onnx/upload/` -- 认证: 不需要 -- 请求类型: `multipart/form-data` -- 参数: - - `file`: `.pt` 模型文件(必填) - - `imgsz`: 导出图像尺寸(如 `640` 或 `640,640`,默认 `640`) - - `opset`: ONNX opset 版本(默认 `12`) - - `simplify`: 是否简化 ONNX(`true/false`,默认 
`false`) -- 参考实现: `views.py` 中 `upload_pt_convert_onnx`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:877) -- 示例请求: - ```bash - curl -X POST "http://localhost:8000/api/yolo/onnx/upload/" \ - -F "file=@/path/to/best.pt" \ - -F "imgsz=640" \ - -F "opset=12" \ - -F "simplify=true" - ``` -- 示例响应: - ```json - { - "success": true, - "code": 200, - "message": "ONNX 导出成功", - "data": { - "onnx_relative_path": "yolo/ONNX/best_xxxxxxxx.onnx", - "download_url": "http://localhost:8000/media/yolo/ONNX/best_xxxxxxxx.onnx", - "labels_relative_path": "yolo/ONNX/best_xxxxxxxx.labels.json", - "labels_download_url": "http://localhost:8000/media/yolo/ONNX/best_xxxxxxxx.labels.json" - } - } - ``` - - -## 二、模型管理 - -### (1)获取模型列表 -- 方法: `GET` -- 路径: `/api/yolo/models/` -- 认证: 不需要 -- 参考实现: `model_list`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:101) -- 示例响应: - ```json - { - "success": true, - "code": 200, - "message": "获取模型列表成功", - "data": [ - {"id": 1, "name": "ModelA", "version": "1.0", "is_enabled": true, "created_at": "2025-10-22T03:20:00Z"} - ] - } - ``` - -### (2)获取模型详情 -- 方法: `GET` -- 路径: `/api/yolo/models/{pk}/` -- 认证: 不需要 -- 参考实现: `model_detail`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:121) -- 示例响应(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取模型详情成功", - "data": { - "id": 1, - "name": "ModelA", - "version": "1.0", - "model_file": "/media/yolo/models/modela.pt", - "model_folder_path": "/abs/path/models/modela_...", - "model_path": "/abs/path/models/modela_/weights/best.pt", - "weights_folder_path": "/abs/path/models/modela_/weights", - "categories": {"0": "person"}, - "is_enabled": true, - "description": "...", - "created_at": "...", - "updated_at": "..." 
- } - } - ``` - -### (2)更新模型 -- 方法: `PUT` 或 `PATCH` -- 路径: `/api/yolo/models/{pk}/update/` -- 认证: 不需要 -- 请求类型: `application/json` 或 `multipart/form-data` -- 可更新字段: `description`, `is_enabled`,(如上传 `model_file` 必须为 `.pt`) -- 参考实现: `model_update`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:134) -- 示例请求(PATCH): - ```http - PATCH /api/yolo/models/1/update/ - Content-Type: application/json - - { - "description": "更新描述", - "is_enabled": true - } - ``` -- 示例响应: - ```json - {"success": true, "code": 200, "message": "模型更新成功", "data": {"id": 1, "name": "ModelA", "is_enabled": true}} - ``` - -### (3)删除模型 -- 方法: `DELETE` -- 路径: `/api/yolo/models/{pk}/delete/` -- 认证: 不需要 -- 参考实现: `model_delete`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:152) -- 示例响应: - ```json - {"success": true, "code": 200, "message": "模型删除成功"} - ``` - -### (4)启用指定模型 -- 方法: `POST` -- 路径: `/api/yolo/models/{pk}/enable/` -- 认证: 不需要 -- 行为: 先禁用其他模型,再启用当前模型 -- 参考实现: `model_enable`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:165) -- 示例响应: - ```json - {"success": true, "code": 200, "message": "模型 ModelA 已启用", "data": {"id": 1, "is_enabled": true}} - ``` - -### (5)获取当前启用的模型 -- 方法: `GET` -- 路径: `/api/yolo/models/enabled/` -- 认证: 不需要 -- 参考实现: `model_enabled`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:186) -- 示例响应: - ```json - {"success": true, "code": 200, "message": "获取启用模型成功", "data": {"id": 1, "name": "ModelA"}} - ``` - -### (6)创建模型(占位) -- 方法: `POST` -- 路径: `/api/yolo/models/create/` -- 说明: 返回 405,提示使用 `/api/yolo/upload/` -- 参考实现: `model_create`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:114) - - -## 三、模型类别管理 - -> 提示:类别通常随模型上传自动导入。手动创建/删除不推荐,仅保留接口。 - -### (1)获取类别列表 -- 方法: `GET` -- 路径: `/api/yolo/categories/` -- 认证: 不需要 -- 参考实现: `category_list`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1301) - -### (2)获取类别详情 -- 方法: `GET` -- 路径: `/api/yolo/categories/{pk}/` -- 认证: 不需要 -- 参考实现: 
`category_detail`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1329) - -### (3)更新类别 -- 方法: `PUT` 或 `PATCH` -- 路径: `/api/yolo/categories/{pk}/update/` -- 认证: 不需要 -- 请求类型: `application/json` -- 可更新字段: `alias`, `alert_level`(`high|medium|low|none`), `is_active` -- 参考实现: `category_update`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1342) -- 示例请求(PATCH): - ```http - PATCH /api/yolo/categories/10/update/ - Content-Type: application/json - - {"alias": "行人", "alert_level": "high", "is_active": true} - ``` -- 示例响应: - ```json - {"success": true, "code": 200, "message": "更新类别成功", "data": {"id": 10, "alias": "行人", "alert_level": "high"}} - ``` - -### (4)切换类别启用状态 -- 方法: `POST` -- 路径: `/api/yolo/categories/{pk}/toggle-status/` -- 认证: 不需要 -- 参考实现: `category_toggle_status`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1385) -- 示例响应: - ```json - {"success": true, "code": 200, "message": "类别 'person' 启用成功", "data": {"is_active": true}} - ``` - -### (5)获取启用的类别列表 -- 方法: `GET` -- 路径: `/api/yolo/categories/active/` -- 认证: 不需要 -- 参考实现: `category_active_list`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1405) - -### (6)创建类别(不推荐) -- 方法: `POST` -- 路径: `/api/yolo/categories/create/` -- 认证: 不需要 -- 请求类型: `application/json` -- 参考实现: `category_create`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1312) - -### (7)删除类别(不推荐) -- 方法: `DELETE` -- 路径: `/api/yolo/categories/{pk}/delete/` -- 认证: 不需要 -- 参考实现: `category_delete`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:1362) - - -## 四、目标检测 - -### 执行检测 -- 方法: `POST` -- 路径: `/api/yolo/detect/` -- 认证: 需要登录(`Authorization: Bearer `) -- 请求类型: `multipart/form-data` -- 参数: - - `file`: 要检测的图片或视频文件(支持图片:`.jpg,.jpeg,.png,.bmp,.tiff,.webp`;视频:`.mp4,.avi,.mov,.mkv,.wmv,.flv`) - - `model_id`: 指定模型ID(可选,未提供则使用当前启用模型) - - `confidence_threshold`: 置信度阈值(默认 `0.5`,范围 `0.1-1.0`) -- 参考实现: `yolo_detection`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:446) -- 示例请求: - ```bash - curl -X POST 
"http://localhost:8000/api/yolo/detect/" \ - -H "Authorization: Bearer " \ - -F "file=@/path/to/image.jpg" \ - -F "model_id=1" \ - -F "confidence_threshold=0.5" - ``` -- 示例响应: - ```json - { - "success": true, - "code": 200, - "message": "检测完成", - "data": { - "detection_id": 1001, - "result_file_url": "/media/detection/result/result_xxx.jpg", - "original_file_url": "/media/detection/original/uuid_image.jpg", - "object_count": 3, - "detected_categories": ["person"], - "confidence_scores": [0.91, 0.87, 0.79], - "avg_confidence": 0.8567, - "processing_time": 0.43, - "model_used": "ModelA 1.0", - "confidence_threshold": 0.5, - "user_id": 2, - "user_name": "alice", - "alert_level": "medium" - } - } - ``` - - -## 五、检测记录 - -### (1)获取检测记录列表 -- 方法: `GET` -- 路径: `/api/yolo/detections/` -- 认证: 不需要 -- 查询参数: - - `type`: `image` 或 `video` - - `model_id`: 模型ID - - `user_id`: 用户ID -- 参考实现: `detection_list`(d:\All_template\yolo\hertz_studio_django_yolo\views.py:204) -- 示例响应(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取检测记录列表成功", - "data": [ - { - "id": 1001, - "original_file": "/media/detection/original/uuid_image.jpg", - "result_file": "/media/detection/result/result_xxx.jpg", - "original_filename": "uuid_image.jpg", - "result_filename": "result_xxx.jpg", - "detection_type": "image", - "model_name": "ModelA 1.0", - "model_info": {"id":1, "name":"ModelA", "version":"1.0"}, - "object_count": 3, - "detected_categories": ["person"], - "confidence_threshold": 0.5, - "confidence_scores": [0.91, 0.87, 0.79], - "avg_confidence": 0.8567, - "processing_time": 0.43, - "created_at": "..." 
- } - ] - } - ``` - -### (2)获取指定用户的检测记录 -- 方法: `GET` -- 路径: `/api/yolo/detections/{user_id}/user/` -- 认证: 不需要 -- 查询参数同上 -- 参考实现: `user_detection_records`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:231) - -### (3)获取检测记录详情 -- 方法: `GET` -- 路径: `/api/yolo/detections/{pk}/` -- 认证: 不需要 -- 参考实现: `detection_detail`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:253) - -### (4)删除检测记录 -- 方法: `DELETE` -- 路径: `/api/yolo/detections/{pk}/delete/` -- 认证: 不需要 -- 行为: 同时删除其关联的原始文件、结果文件及关联的告警 -- 参考实现: `detection_delete`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:265) -- 示例响应: - ```json - {"success": true, "code": 200, "message": "检测记录删除成功"} - ``` - -### (5)批量删除检测记录 -- 方法: `POST` -- 路径: `/api/yolo/detections/batch-delete/` -- 认证: 不需要 -- 请求类型: `application/json` -- 请求体: - ```json - {"ids": [1001, 1002, 1003]} - ``` -- 参考实现: `detection_batch_delete`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:299) -- 示例响应: - ```json - { - "success": true, - "code": 200, - "message": "成功删除 3 条检测记录", - "data": { - "deleted_count": 3, - "found_ids": ["1001","1002","1003"], - "not_found_ids": [] - } - } - ``` - -### (6)检测统计 -- 方法: `GET` -- 路径: `/api/yolo/stats/` -- 认证: 不需要 -- 参考实现: `detection_stats`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:840) - - -## 六、告警记录 - -### (1)获取告警记录列表(管理员) -- 方法: `GET` -- 路径: `/api/yolo/alerts/` -- 认证: 不需要 -- 查询参数: - - `status`: 默认 `pending`;传 `all` 表示不过滤 - - `level`: 告警等级(`high|medium|low|none`) - - `user_id`: 用户ID - - `alter_category`: 告警类别关键字(注意字段名为 `alter_category`) -- 参考实现: `alert_list`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:358) - -### (2)获取用户的告警记录 -- 方法: `GET` -- 路径: `/api/yolo/users/{user_id}/alerts/` -- 认证: 需要登录(仅本人或管理员可查) -- 查询参数: - - `status`: `pending|is_confirm|false_positive|all` - - `level`: `high|medium|low|none` - - `category`: 类别关键字 -- 参考实现: `user_alert_records`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:391) - -### (3)更新告警状态 -- 方法: `PUT` 或 `PATCH` -- 路径: 
`/api/yolo/alerts/{pk}/update-status/` -- 认证: 不需要 -- 请求类型: `application/json` -- 请求体: - ```json - {"status": "is_confirm"} - ``` -- 可选值: `pending`, `is_confirm`, `false_positive` -- 参考实现: `alert_update_status`(d:\AllTemplate\yolo\hertz_studio_django_yolo\views.py:426) -- 示例响应: - ```json - { - "success": true, - "code": 200, - "message": "更新告警状态成功", - "data": { - "id": 555, - "status": "is_confirm", - "alert_level": "medium", - "alert_category": "person", - "alert_level_display": "中" - } - } - ``` - - -## 七、备注 -- 所有文件型字段响应通常包含可直接访问的媒体 URL,媒体服务由 `MEDIA_URL=/media/` 提供。 -- 分类的告警等级枚举参考 `ModelCategory.ALERT_LEVELS` 与 `Alert.ALERT_LEVELS`(d:\AllTemplate\yolo\hertz_studio_django_yolo\models.py:118,153)。 -- 检测请求的文件大小限制:图片 ≤ 50MB,视频 ≤ 500MB(d:\AllTemplate\yolo\hertz_studio_django_yolo\serializers.py:99)。 - diff --git a/docs/API接口文档/代码生成模块接口文档.md b/docs/API接口文档/代码生成模块接口文档.md deleted file mode 100644 index 3ff6c2d..0000000 --- a/docs/API接口文档/代码生成模块接口文档.md +++ /dev/null @@ -1,123 +0,0 @@ -# Hertz Studio Django 代码生成模块接口文档 - -- 基础路径: `/api/codegen/` -- 统一响应: 使用 `HertzResponse`,结构如下(参考 `hertz_studio_django_utils/responses/HertzResponse.py`) - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由中已通过 `path('api/codegen/', include(('hertz_studio_django_codegen.urls', 'hertz_studio_django_codegen'), namespace='codegen'))` 挂载(`hertz_server_django/urls.py:51`)。 -- 认证说明: 接口需要登录,需在请求头携带 `Authorization: Bearer `(`hertz_studio_django_codegen/views/simple_generator_views.py:136`)。 - - -## 一、简化代码生成 - -### 生成应用代码与菜单配置 -- 方法: `POST` -- 路径: `/api/codegen/simple/generate/` -- 认证: 需要登录 -- 实现: `simple_code_generate`(`hertz_studio_django_codegen/views/simple_generator_views.py:136`) -- 请求体: `application/json` -- 字段: - - `module_name` 字符串,模块中文名,例如 `产品管理` - - `model_name` 字符串,模型英文名,例如 `Product` - - `app_name` 字符串,Django 应用名称,例如 `hertz_studio_django_product` - - `fields` 
数组,字段定义列表,每项包含:`name`、`type`、`verbose_name`、`help_text`、`required`、`max_length`、`choices` - - `operations` 数组,支持的操作,默认 `['list','create','update','delete']` - - `menu_config` 对象,菜单配置:`enabled`、`parent_code`、`prefix`、`sort_order`、`icon` - - `generate_app` 布尔,是否生成应用代码,默认 `true` - - `generate_menu` 布尔,是否生成菜单配置,默认 `true` -- 请求示例(同时生成应用与菜单): - ```http - POST /api/codegen/simple/generate/ - Authorization: Bearer - Content-Type: application/json - - { - "module_name": "产品管理", - "model_name": "Product", - "app_name": "hertz_studio_django_product", - "fields": [ - {"name": "name", "type": "CharField", "verbose_name": "产品名称", "required": true, "max_length": 100}, - {"name": "price", "type": "FloatField", "verbose_name": "价格", "required": true}, - {"name": "status", "type": "IntegerField", "verbose_name": "状态", "choices": [[0,"下线"],[1,"上线"]]} - ], - "operations": ["list","create","update","delete"], - "menu_config": {"enabled": true, "parent_code": "system", "prefix": "product", "sort_order": 10, "icon": "box"}, - "generate_app": true, - "generate_menu": true - } - ``` -- 返回示例(成功): - ```json - { - "success": true, - "code": 200, - "message": "成功生成产品管理模块代码和菜单配置", - "data": { - "generated_files": { - "hertz_studio_django_product/models.py": "...", - "hertz_studio_django_product/serializers.py": "...", - "hertz_studio_django_product/views.py": "...", - "hertz_studio_django_product/urls.py": "...", - "hertz_studio_django_product/apps.py": "..." 
- }, - "menu_configs": [ - {"code": "product", "name": "产品管理", "type": "menu", "sort_order": 10}, - {"code": "product:list", "name": "产品列表", "type": "permission", "sort_order": 11} - ], - "app_path": "hertz_studio_django_product", - "menu_file": "d:/All_template/yolo/pending_menus_product.py" - } - } - ``` -- 请求示例(仅生成菜单配置): - ```http - POST /api/codegen/simple/generate/ - Authorization: Bearer - Content-Type: application/json - - { - "module_name": "库存管理", - "model_name": "Inventory", - "app_name": "hertz_studio_django_inventory", - "fields": [], - "operations": ["list"], - "menu_config": {"enabled": true, "parent_code": "system", "prefix": "inventory"}, - "generate_app": false, - "generate_menu": true - } - ``` -- 返回示例(仅菜单): - ```json - { - "success": true, - "code": 200, - "message": "成功生成库存管理模块代码和菜单配置", - "data": { - "generated_files": {}, - "menu_configs": [{"code": "inventory", "name": "库存管理", "type": "menu"}], - "app_path": "", - "menu_file": "d:/All_template/yolo/pending_menus_inventory.py" - } - } - ``` - - -## 二、错误响应示例 -- 缺少必填参数: - ```json - {"success": false, "code": 422, "message": "缺少必填参数: fields"} - ``` -- 请求体格式错误(非JSON): - ```json - {"success": false, "code": 422, "message": "请求参数格式错误,请使用JSON格式"} - ``` -- 生成失败(异常信息): - ```json - {"success": false, "code": 500, "message": "代码生成失败: <错误信息>"} - ``` \ No newline at end of file diff --git a/docs/API接口文档/日志模块接口文档.md b/docs/API接口文档/日志模块接口文档.md deleted file mode 100644 index 25cbf9f..0000000 --- a/docs/API接口文档/日志模块接口文档.md +++ /dev/null @@ -1,121 +0,0 @@ -# Hertz Studio Django 日志模块接口文档 - -- 基础路径: `/api/log/` -- 统一响应: 使用 `HertzResponse`,结构如下(参考 `hertz_studio_django_utils/responses/HertzResponse.py`) - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由中已通过 `path('api/log/', include('hertz_studio_django_log.urls'))` 挂载(`hertz_server_django/urls.py:46`)。 -- 认证与权限: 接口需要登录并具备相应权限,其中列表接口需 `system:log:list`,详情接口需 
`system:log:query`(`hertz_studio_django_log/views/log_views.py:18`, `225`)。 - - -## 一、操作日志列表 - -### 获取操作日志列表 -- 方法: `GET` -- 路径: `/api/log/list/` -- 权限: 仅管理员(`system:log:list`) -- 查询参数: - - `user_id`: 用户ID - - `username`: 用户名(模糊匹配) - - `action_type`: 操作类型(`create|update|delete|view|list|login|logout|export|import|other`) - - `module`: 操作模块(模糊匹配) - - `status`: 操作状态(`0|1`) - - `ip_address`: IP地址 - - `start_date`: 开始时间(ISO日期时间) - - `end_date`: 结束时间(ISO日期时间) - - `page`: 页码,默认 `1` - - `page_size`: 每页数量,默认 `20` -- 实现: `operation_log_list`(`hertz_studio_django_log/views/log_views.py:117`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/log/list/?username=admin&action_type=login&page=1&page_size=10" \ - -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取操作日志列表成功", - "data": { - "logs": [ - { - "log_id": 1001, - "username": "admin", - "action_type": "login", - "action_type_display": "登录", - "module": "认证", - "description": "用户登录", - "ip_address": "127.0.0.1", - "response_status": 200, - "status": 1, - "status_display": "成功", - "is_success": true, - "created_at": "2025-11-17T08:30:00Z" - } - ], - "pagination": { - "page": 1, - "page_size": 10, - "total_count": 1, - "total_pages": 1, - "has_next": false, - "has_previous": false - } - } - } - ``` - - -## 二、操作日志详情 - -### 获取操作日志详情 -- 方法: `GET` -- 路径: `/api/log/detail/{log_id}/` -- 权限: 仅管理员(`system:log:query`) -- 路径参数: `log_id` 日志ID -- 实现: `operation_log_detail`(`hertz_studio_django_log/views/log_views.py:259`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/log/detail/1001/" \ - -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取操作日志详情成功", - "data": { - "log_id": 1001, - "user": 1, - "username": "admin", - "action_type": "login", - "action_type_display": "登录", - "module": "认证", - "description": "用户登录", - "target_model": null, - "target_id": null, - "ip_address": "127.0.0.1", - "user_agent": 
"Mozilla/5.0", - "request_data": {"username": "admin"}, - "formatted_request_data": "{\n \"username\": \"admin\"\n}", - "response_status": 200, - "status": 1, - "status_display": "成功", - "is_success": true, - "created_at": "2025-11-17T08:30:00Z" - } - } - ``` - - -## 三、备注 -- 列表返回采用 `OperationLogListSerializer` 字段子集以简化展示。 -- 详情返回使用 `OperationLogSerializer`,包含格式化的请求数据与成功判断等辅助字段。 \ No newline at end of file diff --git a/docs/API接口文档/系统监控模块接口文档.md b/docs/API接口文档/系统监控模块接口文档.md deleted file mode 100644 index 55f7cc0..0000000 --- a/docs/API接口文档/系统监控模块接口文档.md +++ /dev/null @@ -1,302 +0,0 @@ -# Hertz Studio Django 系统监控模块接口文档 - -- 基础路径: `/api/system/` -- 统一响应: 使用 `HertzResponse`,结构如下(参考 `hertz_studio_django_utils/responses/HertzResponse.py`) - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由中已通过 `path('api/system/', include('hertz_studio_django_system_monitor.urls'))` 挂载(`hertz_server_django/urls.py:23`)。 -- 认证说明: 所有接口均需要登录,需在请求头携带 `Authorization: Bearer `(`hertz_studio_django_auth/utils/decorators/auth_decorators.py:1`)。 - - -## 一、系统信息 - -### 获取系统信息 -- 方法: `GET` -- 路径: `/api/system/system/` -- 认证: 需要登录 -- 实现: `SystemInfoView.get`(`hertz_studio_django_system_monitor/views.py:63`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/system/system/" \ - -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "hostname": "DESKTOP-ABC123", - "platform": "Windows-10-10.0.19041-SP0", - "architecture": "AMD64", - "boot_time": "2025-11-16T08:30:00Z", - "uptime": "2 days, 14:30:00" - } - } - ``` - - -## 二、CPU 信息 - -### 获取CPU信息 -- 方法: `GET` -- 路径: `/api/system/cpu/` -- 认证: 需要登录 -- 实现: `CPUInfoView.get`(`hertz_studio_django_system_monitor/views.py:63`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/system/cpu/" \ - -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "cpu_count": 8, - 
"cpu_percent": 25.6, - "cpu_freq": {"current": 2400.0, "min": 800.0, "max": 3600.0}, - "load_avg": [1.2, 1.5, 1.8] - } - } - ``` - - -## 三、内存信息 - -### 获取内存信息 -- 方法: `GET` -- 路径: `/api/system/memory/` -- 认证: 需要登录 -- 实现: `MemoryInfoView.get`(`hertz_studio_django_system_monitor/views.py:94`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/system/memory/" \ - -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "total": 17179869184, - "available": 8589934592, - "used": 8589934592, - "percent": 50.0, - "free": 8589934592 - } - } - ``` - - -## 四、磁盘信息 - -### 获取磁盘信息 -- 方法: `GET` -- 路径: `/api/system/disks/` -- 认证: 需要登录 -- 实现: `DiskInfoView.get`(`hertz_studio_django_system_monitor/views.py:127`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/system/disks/" \ - -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": [ - { - "device": "C:\\", - "mountpoint": "C:\\", - "fstype": "NTFS", - "total": 1073741824000, - "used": 536870912000, - "free": 536870912000, - "percent": 50.0 - } - ] - } - ``` - - -## 五、网络信息 - -### 获取网络信息 -- 方法: `GET` -- 路径: `/api/system/network/` -- 认证: 需要登录 -- 实现: `NetworkInfoView.get`(`hertz_studio_django_system_monitor/views.py:170`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/system/network/" \ - -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": [ - { - "interface": "以太网", - "bytes_sent": 1048576000, - "bytes_recv": 2097152000, - "packets_sent": 1000000, - "packets_recv": 1500000 - } - ] - } - ``` - - -## 六、进程信息 - -### 获取进程信息 -- 方法: `GET` -- 路径: `/api/system/processes/` -- 认证: 需要登录 -- 查询参数: - - `limit`: 返回条数,默认 `20` - - `sort_by`: 排序字段,默认 `cpu_percent`(可选 `cpu_percent|memory_percent|create_time`) -- 实现: `ProcessInfoView.get`(`hertz_studio_django_system_monitor/views.py:204`) -- 请求示例: - ```bash - curl 
"http://localhost:8000/api/system/processes/?limit=10&sort_by=cpu_percent" \ - -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": [ - { - "pid": 1234, - "name": "python.exe", - "status": "running", - "cpu_percent": 15.6, - "memory_percent": 8.2, - "memory_info": {"rss": 134217728, "vms": 268435456}, - "create_time": "2025-11-16T10:30:00Z", - "cmdline": ["python", "manage.py", "runserver"] - } - ] - } - ``` - - -## 七、GPU 信息 - -### 获取GPU信息 -- 方法: `GET` -- 路径: `/api/system/gpu/` -- 认证: 需要登录 -- 实现: `GPUInfoView.get`(`hertz_studio_django_system_monitor/views.py:259`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/system/gpu/" \ - -H "Authorization: Bearer " - ``` -- 返回示例(有GPU设备): - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "gpu_available": true, - "gpu_info": [ - { - "id": 0, - "name": "NVIDIA GeForce RTX 3080", - "load": 45.6, - "memory_total": 10240, - "memory_used": 4096, - "memory_util": 40.0, - "temperature": 65 - } - ], - "timestamp": "2025-11-16 18:30:00" - } - } - ``` -- 返回示例(无GPU设备): - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "gpu_available": false, - "message": "未检测到GPU设备", - "timestamp": "2025-11-16 18:30:00" - } - } - ``` -- 返回示例(GPU库不可用): - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "gpu_available": false, - "message": "GPU监控不可用,请安装GPUtil库", - "timestamp": "2025-11-16 18:30:00" - } - } - ``` - - -## 八、综合监控 - -### 获取系统监测综合信息 -- 方法: `GET` -- 路径: `/api/system/monitor/` -- 认证: 需要登录 -- 实现: `SystemMonitorView.get`(`hertz_studio_django_system_monitor/views.py:325`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/system/monitor/" \ - -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "system": {"hostname": "DESKTOP-ABC123", "platform": "Windows-10-10.0.19041-SP0", "architecture": "AMD64", 
"boot_time": "2025-11-16T08:30:00Z", "uptime": "2 days, 14:30:00"}, - "cpu": {"cpu_count": 8, "cpu_percent": 25.6, "cpu_freq": {"current": 2400.0, "min": 800.0, "max": 3600.0}, "load_avg": [1.2, 1.5, 1.8]}, - "memory": {"total": 17179869184, "available": 8589934592, "used": 8589934592, "percent": 50.0, "free": 8589934592}, - "disks": [{"device": "C:\\", "mountpoint": "C:\\", "fstype": "NTFS", "total": 1073741824000, "used": 536870912000, "free": 536870912000, "percent": 50.0}], - "network": [{"interface": "以太网", "bytes_sent": 1048576000, "bytes_recv": 2097152000, "packets_sent": 1000000, "packets_recv": 1500000}], - "processes": [{"pid": 1234, "name": "python.exe", "status": "running", "cpu_percent": 15.6, "memory_percent": 8.2, "memory_info": {"rss": 134217728, "vms": 268435456}, "create_time": "2025-11-16T10:30:00Z", "cmdline": ["python", "manage.py", "runserver"]}], - "gpus": [{"id": 0, "name": "NVIDIA GeForce RTX 3080", "load": 45.6, "memory_total": 10240, "memory_used": 4096, "memory_util": 40.0, "temperature": 65}] - } - } - ``` - - -## 九、错误响应示例 -- 通用错误格式: - ```json - {"success": false, "code": 401, "message": "未授权访问"} - ``` \ No newline at end of file diff --git a/docs/API接口文档/认证模块接口文档.md b/docs/API接口文档/认证模块接口文档.md deleted file mode 100644 index 990a4f6..0000000 --- a/docs/API接口文档/认证模块接口文档.md +++ /dev/null @@ -1,452 +0,0 @@ -# Hertz Studio Django 认证模块接口文档 - -- 基础路径: `/api/` -- 统一响应: 使用 `HertzResponse`,结构如下(参考 `hertz_studio_django_utils/responses/HertzResponse.py`) - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由中已通过 `path('api/', include('hertz_studio_django_auth.urls'))` 挂载(`hertz_server_django/urls.py:31`)。 -- 认证说明: - - 登录、注册、发送邮箱验证码、重置密码:不需要登录 - - 其他接口需要在请求头携带 `Authorization: Bearer `(`hertz_studio_django_auth/utils/decorators/auth_decorators.py:24`)。 -- 路由前缀说明: - - 认证相关接口前缀为 `/api/auth/` - - 管理接口(用户/角色/菜单/部门)前缀为 `/api/` - - -## 一、用户认证 - -### (1)用户登录 -- 方法: `POST` -- 路径: `/api/auth/login/` -- 认证: 
不需要登录 -- 请求体: `application/json` -- 字段: `username`, `password`, `captcha_code`, `captcha_key` -- 实现: `user_login`(`hertz_studio_django_auth/views/auth_views.py:132`) -- 请求示例: - ```http - POST /api/auth/login/ - Content-Type: application/json - - { - "username": "admin", - "password": "Passw0rd!", - "captcha_code": "A1B2", - "captcha_key": "c1a2b3c4-d5e6-7890-abcd-ef1234567890" - } - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "登录成功", - "data": { - "access_token": "eyJhbGci...", - "refresh_token": "eyJhbGci...", - "user_info": { - "user_id": 1, - "username": "admin", - "email": "admin@example.com", - "phone": "13800000000", - "real_name": "管理员", - "avatar": null, - "roles": [{"role_id": 1, "role_name": "管理员", "role_code": "admin"}], - "permissions": ["system:user:list", "system:role:list"] - } - } - } - ``` - -### (2)用户注册 -- 方法: `POST` -- 路径: `/api/auth/register/` -- 认证: 不需要登录 -- 请求体: `application/json` -- 字段: `username`, `password`, `confirm_password`, `email`, `phone`, `real_name`, `email_code?` -- 实现: `user_register`(`hertz_studio_django_auth/views/auth_views.py:131`) -- 请求示例: - ```http - POST /api/auth/register/ - Content-Type: application/json - - { - "username": "newuser", - "password": "Passw0rd!", - "confirm_password": "Passw0rd!", - "email": "new@example.com", - "phone": "13800000001", - "real_name": "新用户", - "email_code": "123456" - } - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "注册成功", "data": {"user_id": 12, "username": "newuser"}} - ``` - -### (3)用户登出 -- 方法: `POST` -- 路径: `/api/auth/logout/` -- 认证: 需要登录 -- 实现: `user_logout`(`hertz_studio_django_auth/views/auth_views.py:184`) -- 请求示例: - ```bash - curl -X POST "http://localhost:8000/api/auth/logout/" -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "登出成功"} - ``` - -### (4)修改密码 -- 方法: `POST` -- 路径: `/api/auth/password/change/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `old_password`, `new_password`, 
`confirm_password` -- 实现: `change_password`(`hertz_studio_django_auth/views/auth_views.py:214`) -- 请求示例: - ```http - POST /api/auth/password/change/ - Authorization: Bearer - Content-Type: application/json - - {"old_password": "Passw0rd!", "new_password": "N3wPass!", "confirm_password": "N3wPass!"} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "密码修改成功"} - ``` - -### (5)重置密码 -- 方法: `POST` -- 路径: `/api/auth/password/reset/` -- 认证: 不需要登录 -- 请求体: `application/json` -- 字段: `email`, `email_code`, `new_password`, `confirm_password` -- 实现: `reset_password`(`hertz_studio_django_auth/views/auth_views.py:259`) -- 请求示例: - ```http - POST /api/auth/password/reset/ - Content-Type: application/json - - {"email": "user@example.com", "email_code": "654321", "new_password": "N3wPass!", "confirm_password": "N3wPass!"} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "密码重置成功"} - ``` - -### (6)获取当前用户信息 -- 方法: `GET` -- 路径: `/api/auth/user/info/` -- 认证: 需要登录 -- 实现: `get_user_info`(`hertz_studio_django_auth/views/auth_views.py:310`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/auth/user/info/" -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": { - "user_id": 1, - "username": "admin", - "email": "admin@example.com", - "phone": "13800000000", - "real_name": "管理员", - "department_id": 2, - "department_name": "技术部", - "roles": [{"role_id": 1, "role_name": "管理员", "role_code": "admin"}], - "last_login_time": "2025-11-17T09:00:00Z" - } - } - ``` - -### (7)更新当前用户信息 -- 方法: `PUT` -- 路径: `/api/auth/user/info/update/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `email`, `phone`, `real_name`, `avatar`, `gender`, `birthday` -- 实现: `update_user_info`(`hertz_studio_django_auth/views/auth_views.py:339`) -- 请求示例: - ```http - PUT /api/auth/user/info/update/ - Authorization: Bearer - Content-Type: application/json - - {"real_name": "张三", "phone": "13800000002"} - ``` -- 返回示例: - 
```json - {"success": true, "code": 200, "message": "用户信息更新成功", "data": {"real_name": "张三", "phone": "13800000002"}} - ``` - -### (8)获取用户菜单 -- 方法: `GET` -- 路径: `/api/auth/user/menus/` -- 认证: 需要登录 -- 实现: `get_user_menus`(`hertz_studio_django_auth/views/auth_views.py:368`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/auth/user/menus/" -H "Authorization: Bearer " - ``` -- 返回示例(树形结构): - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": [ - {"menu_id": 1, "menu_name": "系统管理", "menu_code": "system", "menu_type": 1, "path": "/system", "children": [ - {"menu_id": 2, "menu_name": "用户管理", "menu_code": "system:user", "menu_type": 2, "path": "/system/users", "children": []} - ]} - ] - } - ``` - -### (9)发送邮箱验证码 -- 方法: `POST` -- 路径: `/api/auth/email/code/` -- 认证: 不需要登录 -- 请求体: `application/json` -- 字段: `email`, `code_type`(如 `register|reset_password`) -- 实现: `send_email_code`(`hertz_studio_django_auth/views/auth_views.py:441`) -- 请求示例: - ```http - POST /api/auth/email/code/ - Content-Type: application/json - - {"email": "user@example.com", "code_type": "register"} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "验证码发送成功,请查收邮件", "data": {"email": "user@example.com", "code_type": "register", "expires_in": 300}} - ``` - -### (10)刷新访问令牌 -- 方法: `POST` -- 路径: `/api/auth/token/refresh/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `refresh_token` -- 实现: `refresh_token`(`hertz_studio_django_auth/views/auth_views.py:544`) -- 请求示例: - ```http - POST /api/auth/token/refresh/ - Authorization: Bearer - Content-Type: application/json - - {"refresh_token": "eyJhbGci..."} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "token刷新成功", "data": {"access_token": "eyJhbGci..."}} - ``` - - -## 二、用户管理 - -- 前缀: `/api/` - -### (1)获取用户列表 -- 方法: `GET` -- 路径: `/api/users/` -- 权限: `system:user:list` -- 查询参数: `page`, `page_size`, `username`, `email`, `real_name`, `department_id`, `status` -- 实现: 
`user_list`(`hertz_studio_django_auth/views/management_views.py:38`) - -### (2)创建用户 -- 方法: `POST` -- 路径: `/api/users/create/` -- 权限: `system:user:add` -- 请求体: `UserManagementSerializer` 字段 -- 实现: `user_create`(`hertz_studio_django_auth/views/management_views.py:106`) - -### (3)获取用户详情 -- 方法: `GET` -- 路径: `/api/users/{user_id}/` -- 权限: `system:user:query` -- 实现: `user_detail`(`hertz_studio_django_auth/views/management_views.py:149`) - -### (4)更新用户 -- 方法: `PUT` -- 路径: `/api/users/{user_id}/update/` -- 权限: `system:user:edit` -- 请求体: `UserManagementSerializer` -- 实现: `user_update`(`hertz_studio_django_auth/views/management_views.py:190`) - -### (5)删除用户 -- 方法: `DELETE` -- 路径: `/api/users/{user_id}/delete/` -- 权限: `system:user:remove` -- 实现: `user_delete`(`hertz_studio_django_auth/views/management_views.py:236`) - -### (6)分配用户角色 -- 方法: `POST` -- 路径: `/api/users/assign-roles/` -- 权限: `system:user:role` -- 请求体: `user_id`, `role_ids: number[]` -- 实现: `user_assign_roles`(`hertz_studio_django_auth/views/management_views.py:279`) - - -## 三、角色管理 - -- 前缀: `/api/` - -### (1)获取角色列表 -- 方法: `GET` -- 路径: `/api/roles/` -- 权限: `system:role:list` -- 查询参数: `page`, `page_size`, `role_name`, `role_code`, `status` -- 实现: `role_list`(`hertz_studio_django_auth/views/management_views.py:328`) - -### (2)创建角色 -- 方法: `POST` -- 路径: `/api/roles/create/` -- 权限: `system:role:add` -- 请求体: `RoleManagementSerializer` -- 实现: `role_create`(`hertz_studio_django_auth/views/management_views.py:393`) - -### (3)获取角色详情 -- 方法: `GET` -- 路径: `/api/roles/{role_id}/` -- 权限: `system:role:query` -- 实现: `role_detail`(`hertz_studio_django_auth/views/management_views.py:437`) - -### (4)更新角色 -- 方法: `PUT` -- 路径: `/api/roles/{role_id}/update/` -- 权限: `system:role:edit` -- 请求体: `RoleManagementSerializer` -- 实现: `role_update`(`hertz_studio_django_auth/views/management_views.py:477`) - -### (5)删除角色 -- 方法: `DELETE` -- 路径: `/api/roles/{role_id}/delete/` -- 权限: `system:role:remove` -- 实现: 
`role_delete`(`hertz_studio_django_auth/views/management_views.py:527`) - -### (6)分配角色菜单 -- 方法: `POST` -- 路径: `/api/roles/assign-menus/` -- 权限: `system:role:menu` -- 请求体: `role_id`, `menu_ids: number[]` -- 实现: `role_assign_menus`(`hertz_studio_django_auth/views/management_views.py:579`) - -### (7)获取角色菜单 -- 方法: `GET` -- 路径: `/api/roles/{role_id}/menus/` -- 权限: `system:role:menu` -- 实现: `role_menus`(`hertz_studio_django_auth/views/management_views.py:631`) - - -## 四、菜单管理 - -- 前缀: `/api/` - -### (1)获取菜单列表 -- 方法: `GET` -- 路径: `/api/menus/` -- 权限: `system:menu:list` -- 实现: `menu_list`(`hertz_studio_django_auth/views/management_views.py:665`) - -### (2)获取菜单树 -- 方法: `GET` -- 路径: `/api/menus/tree/` -- 权限: `system:menu:list` -- 实现: `menu_tree`(`hertz_studio_django_auth/views/management_views.py:710`) - -### (3)创建菜单 -- 方法: `POST` -- 路径: `/api/menus/create/` -- 权限: `system:menu:add` -- 请求体: `MenuManagementSerializer` -- 实现: `menu_create`(`hertz_studio_django_auth/views/management_views.py:744`) - -### (4)获取菜单详情 -- 方法: `GET` -- 路径: `/api/menus/{menu_id}/` -- 权限: `system:menu:query` -- 实现: `menu_detail`(`hertz_studio_django_auth/views/management_views.py:787`) - -### (5)更新菜单 -- 方法: `PUT` -- 路径: `/api/menus/{menu_id}/update/` -- 权限: `system:menu:edit` -- 请求体: `MenuManagementSerializer` -- 实现: `menu_update`(`hertz_studio_django_auth/views/management_views.py:828`) - -### (6)删除菜单 -- 方法: `DELETE` -- 路径: `/api/menus/{menu_id}/delete/` -- 权限: `system:menu:remove` -- 实现: `menu_delete`(`hertz_studio_django_auth/views/management_views.py:878`) - - -## 五、部门管理 - -- 前缀: `/api/` - -### (1)获取部门列表 -- 方法: `GET` -- 路径: `/api/departments/` -- 权限: `system:dept:list` -- 实现: `department_list`(`hertz_studio_django_auth/views/management_views.py:921`) - -### (2)获取部门树 -- 方法: `GET` -- 路径: `/api/departments/tree/` -- 权限: `system:dept:list` -- 实现: `department_tree`(`hertz_studio_django_auth/views/management_views.py:963`) - -### (3)创建部门 -- 方法: `POST` -- 路径: `/api/departments/create/` -- 权限: 
`system:dept:add` -- 请求体: `DepartmentManagementSerializer` -- 实现: `department_create`(`hertz_studio_django_auth/views/management_views.py:997`) - -### (4)获取部门详情 -- 方法: `GET` -- 路径: `/api/departments/{dept_id}/` -- 权限: `system:dept:query` -- 实现: `department_detail`(`hertz_studio_django_auth/views/management_views.py:1040`) - -### (5)更新部门 -- 方法: `PUT` -- 路径: `/api/departments/{dept_id}/update/` -- 权限: `system:dept:edit` -- 请求体: `DepartmentManagementSerializer` -- 实现: `department_update`(`hertz_studio_django_auth/views/management_views.py:1081`) - -### (6)删除部门 -- 方法: `DELETE` -- 路径: `/api/departments/{dept_id}/delete/` -- 权限: `system:dept:remove` -- 实现: `department_delete`(`hertz_studio_django_auth/views/management_views.py:1131`) - - -## 六、错误响应示例 -- 未授权访问: - ```json - {"success": false, "code": 401, "message": "未提供认证令牌"} - ``` -- 权限不足: - ```json - {"success": false, "code": 403, "message": "缺少权限:system:user:list"} - ``` -- 参数验证失败: - ```json - {"success": false, "code": 422, "message": "参数验证失败", "errors": {"email": ["邮箱格式不正确"]}} - ``` \ No newline at end of file diff --git a/docs/API接口文档/通知模块接口文档.md b/docs/API接口文档/通知模块接口文档.md deleted file mode 100644 index 58bbcf4..0000000 --- a/docs/API接口文档/通知模块接口文档.md +++ /dev/null @@ -1,391 +0,0 @@ -# Hertz Studio Django 通知模块接口文档 - -- 基础路径: `/api/notice/` -- 统一响应: 使用 `HertzResponse`,结构如下(参考 `hertz_studio_django_utils/responses/HertzResponse.py`) - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由中已通过 `path('api/notice/', include('hertz_studio_django_notice.urls'))` 挂载(`hertz_server_django/urls.py:19`)。 -- 认证说明: 所有接口均需要登录,需在请求头携带 `Authorization: Bearer `(`hertz_studio_django_auth/utils/decorators/auth_decorators.py:1`)。 - - -## 一、管理员通知接口 - -### (1)创建通知 -- 方法: `POST` -- 路径: `/api/notice/admin/create/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `title`, `content`, `notice_type`, `priority`, `is_top`, `publish_time`, `expire_time`, `attachment_url` -- 实现: 
`admin_create_notice`(`hertz_studio_django_notice/views/admin_views.py:39`) -- 请求示例: - ```http - POST /api/notice/admin/create/ - Authorization: Bearer - Content-Type: application/json - - { - "title": "系统维护通知", - "content": "将于周六晚间进行系统维护。", - "notice_type": 1, - "priority": 3, - "is_top": true, - "publish_time": "2025-11-18 09:00:00", - "expire_time": "2025-11-20 23:59:59", - "attachment_url": null - } - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "通知创建成功", "data": {"notice_id": 101}} - ``` - -### (2)更新通知 -- 方法: `PUT` -- 路径: `/api/notice/admin/update/{notice_id}/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `title`, `content`, `notice_type`, `priority`, `is_top`, `publish_time`, `expire_time`, `attachment_url`, `status` -- 实现: `admin_update_notice`(`hertz_studio_django_notice/views/admin_views.py:94`) -- 请求示例: - ```http - PUT /api/notice/admin/update/101/ - Authorization: Bearer - Content-Type: application/json - - {"priority": 4, "is_top": false} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "通知更新成功"} - ``` - -### (3)删除通知 -- 方法: `DELETE` -- 路径: `/api/notice/admin/delete/{notice_id}/` -- 认证: 需要登录 -- 实现: `admin_delete_notice`(`hertz_studio_django_notice/views/admin_views.py:146`) -- 请求示例: - ```bash - curl -X DELETE "http://localhost:8000/api/notice/admin/delete/101/" \ - -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "通知删除成功"} - ``` - -### (4)获取通知列表 -- 方法: `GET` -- 路径: `/api/notice/admin/list/` -- 认证: 需要登录 -- 查询参数: - - `page`, `page_size` - - `notice_type`, `status`, `priority`, `is_top`, `keyword` - - `start_date`, `end_date`(按发布时间范围筛选) -- 实现: `admin_get_notice_list`(`hertz_studio_django_notice/views/admin_views.py:184`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/notice/admin/list/?page=1&page_size=10&status=1&is_top=true" \ - -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取通知列表成功", - "data": { - 
"notices": [ - { - "notice_id": 101, - "title": "系统维护通知", - "notice_type": 1, - "notice_type_display": "系统通知", - "priority": 3, - "priority_display": "高", - "status": 1, - "status_display": "已发布", - "is_top": true, - "publish_time": "2025-11-18T09:00:00Z", - "expire_time": "2025-11-20T23:59:59Z", - "publisher_name": "管理员", - "view_count": 12, - "is_expired": false, - "read_count": 30, - "unread_count": 5, - "created_at": "2025-11-17T10:00:00Z", - "updated_at": "2025-11-17T10:00:00Z" - } - ], - "pagination": {"current_page": 1, "page_size": 10, "total_pages": 1, "total_count": 1, "has_next": false, "has_previous": false} - } - } - ``` - -### (5)获取通知详情 -- 方法: `GET` -- 路径: `/api/notice/admin/detail/{notice_id}/` -- 认证: 需要登录 -- 实现: `admin_get_notice_detail`(`hertz_studio_django_notice/views/admin_views.py:273`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/notice/admin/detail/101/" -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取通知详情成功", - "data": { - "notice_id": 101, - "title": "系统维护通知", - "content": "将于周六晚间进行系统维护。", - "notice_type": 1, - "notice_type_display": "系统通知", - "priority": 3, - "priority_display": "高", - "status": 1, - "status_display": "已发布", - "is_top": true, - "publish_time": "2025-11-18T09:00:00Z", - "expire_time": "2025-11-20T23:59:59Z", - "attachment_url": null, - "publisher_name": "管理员", - "publisher_username": "admin", - "view_count": 12, - "is_expired": false, - "read_count": 30, - "unread_count": 5, - "created_at": "2025-11-17T10:00:00Z", - "updated_at": "2025-11-17T10:00:00Z" - } - } - ``` - -### (6)发布通知 -- 方法: `POST` -- 路径: `/api/notice/admin/publish/{notice_id}/` -- 认证: 需要登录 -- 实现: `admin_publish_notice`(`hertz_studio_django_notice/views/admin_views.py:317`) -- 请求示例: - ```bash - curl -X POST "http://localhost:8000/api/notice/admin/publish/101/" -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "通知发布成功"} - ``` - -### (7)撤回通知 -- 
方法: `POST` -- 路径: `/api/notice/admin/withdraw/{notice_id}/` -- 认证: 需要登录 -- 实现: `admin_withdraw_notice`(`hertz_studio_django_notice/views/admin_views.py:374`) -- 请求示例: - ```bash - curl -X POST "http://localhost:8000/api/notice/admin/withdraw/101/" -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "通知撤回成功"} - ``` - - -## 二、用户通知接口 - -### (1)获取通知列表 -- 方法: `GET` -- 路径: `/api/notice/user/list/` -- 认证: 需要登录 -- 查询参数: - - `page`, `page_size` - - `notice_type`, `is_read`, `is_starred`, `priority`, `keyword` - - `show_expired`: `true/false`,默认不显示过期通知 -- 实现: `user_get_notice_list`(`hertz_studio_django_notice/views/user_views.py:28`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/notice/user/list/?page=1&page_size=10&is_read=false&is_starred=true" \ - -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取通知列表成功", - "data": { - "notices": [ - { - "title": "系统维护通知", - "notice_type_display": "系统通知", - "priority_display": "高", - "is_top": true, - "publish_time": "2025-11-18T09:00:00Z", - "is_read": false, - "read_time": null, - "is_starred": true, - "starred_time": "2025-11-17T12:00:00Z", - "is_expired": false, - "created_at": "2025-11-17T10:00:00Z" - } - ], - "pagination": { - "current_page": 1, - "page_size": 10, - "total_pages": 1, - "total_count": 1, - "has_next": false, - "has_previous": false - }, - "statistics": {"total_count": 10, "unread_count": 2, "starred_count": 3} - } - } - ``` - -### (2)获取通知详情 -- 方法: `GET` -- 路径: `/api/notice/user/detail/{notice_id}/` -- 认证: 需要登录 -- 行为: 自动标记为已读并增加查看次数 -- 实现: `user_get_notice_detail`(`hertz_studio_django_notice/views/user_views.py:147`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/notice/user/detail/101/" -H "Authorization: Bearer " - ``` -- 返回示例(节选): - ```json - { - "success": true, - "code": 200, - "message": "获取通知详情成功", - "data": { - "title": "系统维护通知", - "content": "将于周六晚间进行系统维护。", - "notice_type_display": "系统通知", - 
"priority_display": "高", - "attachment_url": null, - "publish_time": "2025-11-18T09:00:00Z", - "expire_time": "2025-11-20T23:59:59Z", - "is_top": true, - "is_expired": false, - "publisher_name": "管理员", - "is_read": true, - "read_time": "2025-11-17T12:05:00Z", - "is_starred": true, - "starred_time": "2025-11-17T12:00:00Z" - } - } - ``` - -### (3)标记通知已读 -- 方法: `POST` -- 路径: `/api/notice/user/mark-read/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `notice_id` -- 实现: `user_mark_notice_read`(`hertz_studio_django_notice/views/user_views.py:214`) -- 请求示例: - ```http - POST /api/notice/user/mark-read/ - Authorization: Bearer - Content-Type: application/json - - {"notice_id": 101} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "标记已读成功"} - ``` - -### (4)批量标记已读 -- 方法: `POST` -- 路径: `/api/notice/user/batch-mark-read/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `notice_ids: number[]` -- 实现: `user_batch_mark_read`(`hertz_studio_django_notice/views/user_views.py:260`) -- 请求示例: - ```http - POST /api/notice/user/batch-mark-read/ - Authorization: Bearer - Content-Type: application/json - - {"notice_ids": [101, 102, 103]} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "批量标记已读成功"} - ``` - -### (5)标记全部已读 -- 方法: `POST` -- 路径: `/api/notice/user/mark-all-read/` -- 认证: 需要登录 -- 实现: `user_mark_all_read`(`hertz_studio_django_notice/views/user_views.py:317`) -- 请求示例: - ```bash - curl -X POST "http://localhost:8000/api/notice/user/mark-all-read/" -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "标记全部已读成功,共标记3条通知", "data": {"updated_count": 3}} - ``` - -### (6)切换收藏状态 -- 方法: `POST` -- 路径: `/api/notice/user/toggle-star/` -- 认证: 需要登录 -- 请求体: `application/json` -- 字段: `notice_id`, `is_starred` -- 实现: `user_toggle_notice_star`(`hertz_studio_django_notice/views/user_views.py:401`) -- 请求示例: - ```http - POST /api/notice/user/toggle-star/ - Authorization: Bearer - Content-Type: application/json - - 
{"notice_id": 101, "is_starred": true} - ``` -- 返回示例: - ```json - {"success": true, "code": 200, "message": "收藏成功"} - ``` - -### (7)获取通知统计信息 -- 方法: `GET` -- 路径: `/api/notice/user/statistics/` -- 认证: 需要登录 -- 实现: `user_get_notice_statistics`(`hertz_studio_django_notice/views/user_views.py:417`) -- 请求示例: - ```bash - curl "http://localhost:8000/api/notice/user/statistics/" -H "Authorization: Bearer " - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "获取统计信息成功", - "data": { - "total_count": 10, - "unread_count": 2, - "read_count": 8, - "starred_count": 3, - "type_statistics": {"系统通知": 6, "公告通知": 3, "活动通知": 1, "维护通知": 0}, - "priority_statistics": {"低": 2, "中": 4, "高": 3, "紧急": 1} - } - } - ``` - - -## 三、备注 -- 列表与详情序列包含显示用枚举字段(如 `notice_type_display`, `priority_display`, `status_display`)。 -- 用户视图中 `user_get_notice_detail` 会自动将未读标记为已读并累加 `view_count`。 -- 管理员发布接口会为所有启用用户创建 `HertzUserNotice` 状态记录。 \ No newline at end of file diff --git a/docs/API接口文档/验证码模块接口文档.md b/docs/API接口文档/验证码模块接口文档.md deleted file mode 100644 index 72ba281..0000000 --- a/docs/API接口文档/验证码模块接口文档.md +++ /dev/null @@ -1,83 +0,0 @@ -# Hertz Studio Django 验证码模块接口文档 - -- 基础路径: `/api/captcha/` -- 统一响应: 使用 `HertzResponse`,结构如下(参考 `hertz_studio_django_utils/responses/HertzResponse.py`) - ```json - { - "success": true, - "code": 200, - "message": "操作成功", - "data": {} - } - ``` -- 路由挂载: 项目主路由中已通过 `path('api/captcha/', include('hertz_studio_django_captcha.urls'))` 挂载(`hertz_server_django/urls.py:28`)。 -- 认证说明: 接口允许匿名访问(`AllowAny`),不需要登录(`hertz_studio_django_captcha/api_views.py:18`, `hertz_studio_django_captcha/api_views.py:60`)。 -- 验证接口说明: 独立的验证码验证接口已删除,验证逻辑集成到具体业务接口(`hertz_studio_django_captcha/api_views.py:54`)。 - - -## 一、生成验证码 - -### 生成新的验证码 -- 方法: `POST` -- 路径: `/api/captcha/generate/` -- 认证: 不需要登录 -- 实现: `CaptchaGenerateAPIView.post`(`hertz_studio_django_captcha/api_views.py:38`) -- 请求示例: - ```bash - curl -X POST "http://localhost:8000/api/captcha/generate/" - ``` -- 返回示例: - ```json 
- { - "success": true, - "code": 200, - "message": "验证码生成成功", - "data": { - "captcha_id": "c1a2b3c4-d5e6-7890-abcd-ef1234567890", - "image_data": "data:image/png;base64,iVBORw0KGgo...", - "expires_in": 300 - } - } - ``` - - -## 二、刷新验证码 - -### 刷新已有验证码,生成新图像 -- 方法: `POST` -- 路径: `/api/captcha/refresh/` -- 认证: 不需要登录 -- 请求体: `application/json` -- 字段: `captcha_id` 旧验证码ID -- 实现: `CaptchaRefreshAPIView.post`(`hertz_studio_django_captcha/api_views.py:84`) -- 请求示例: - ```http - POST /api/captcha/refresh/ - Content-Type: application/json - - {"captcha_id": "c1a2b3c4-d5e6-7890-abcd-ef1234567890"} - ``` -- 返回示例: - ```json - { - "success": true, - "code": 200, - "message": "验证码刷新成功", - "data": { - "captcha_id": "f2b3c4d5-e6f7-8901-bcde-0123456789ab", - "image_data": "data:image/png;base64,iVBORw0KGgo...", - "expires_in": 300 - } - } - ``` - - -## 三、错误响应示例 -- 参数不完整(刷新接口): - ```json - {"success": false, "code": 400, "message": "参数不完整"} - ``` -- 生成/刷新失败(异常信息): - ```json - {"success": false, "code": 500, "message": "验证码生成失败", "error": "<错误信息>"} - ``` \ No newline at end of file diff --git a/docs/img/ac87c1f6-a28b-4959-ae98-cab5e150c56b.png b/docs/img/ac87c1f6-a28b-4959-ae98-cab5e150c56b.png deleted file mode 100644 index d3927ef..0000000 Binary files a/docs/img/ac87c1f6-a28b-4959-ae98-cab5e150c56b.png and /dev/null differ diff --git a/docs/img/img_1.png b/docs/img/img_1.png deleted file mode 100644 index 4066932..0000000 Binary files a/docs/img/img_1.png and /dev/null differ diff --git a/docs/img/img_2.png b/docs/img/img_2.png deleted file mode 100644 index 626c471..0000000 Binary files a/docs/img/img_2.png and /dev/null differ diff --git a/docs/img/img_3.png b/docs/img/img_3.png deleted file mode 100644 index d3d0a40..0000000 Binary files a/docs/img/img_3.png and /dev/null differ diff --git a/docs/使用手册.md b/docs/使用手册.md deleted file mode 100644 index 4ce3fa3..0000000 --- a/docs/使用手册.md +++ /dev/null @@ -1,68 +0,0 @@ -# 使用手册 - -## 一、**环境要求** - -- `Python 3.10+`(建议 3.11/3.12) -- 
操作系统:Windows(PowerShell) -- 可选:本地 `Redis` 服务(默认地址 `redis://127.0.0.1:6379`) - -## 二、**依赖安装** - -- 一次性使用镜像安装: - - 在项目根目录终端执行: - - `pip install -r requirements.txt -i https://hertz:hertz@hzpypi.hzsystems.cn/simple/` - -- 注:若提示依赖安装失败,请将机器码保存并联系管理员开通机器码注册 - -## **三、启动服务** - -- 通过脚本启动(支持端口参数): - - `python start_server.py --port 8000` -- 访问地址: - - `http://127.0.0.1:8000/` - - WebSocket:`ws://127.0.0.1:8000/ws/` -- 首次启动将自动执行: - - 扫描并注册新应用到 `INSTALLED_APPS` 与 `urls.py`(`start_server.py:173`、`start_server.py:98`)。 - - 执行 `makemigrations` 与 `migrate`(`start_server.py:1109`)。 - - 初始化超级管理员/部门/菜单/角色等(`start_server.py:877`)。 - - 创建菜单生成器工具 `generate_menu.py`(`start_server.py:780`)。 - - 启动 `daphne` 并开启热重启监听(`start_server.py:1016`、`start_server.py:1063`)。 - -## 四、默认账号 - -- 超级管理员: - - 用户名:`hertz` - - 密码:`hertz` -- 普通用户 - - 用户名:`demo` - - 密码:`123456` - -## 五、**常见配置说明** - -- CORS:通过 `.env` 配置 `CORS_ALLOWED_ORIGINS` 与 `CORS_ALLOW_ALL_ORIGINS` -- 静态与媒体: - - 静态:`STATICFILES_DIRS = [BASE_DIR / 'static']` - - 媒体:`MEDIA_ROOT = BASE_DIR / 'media'`(`hertz_server_django/settings.py:209-216`)。 -- WebSocket:使用 `channels` 与 `channels-redis`,层配置读取 `REDIS_URL`(`hertz_server_django/settings.py:302-309`)。 - -## **六、问题排查** - -- `daphne` 或 `watchdog` 未安装: - - 运行:`pip install daphne watchdog -i https://pypi.tuna.tsinghua.edu.cn/simple`(`start_server.py:1235-1248` 有依赖检查)。 -- Redis 未运行: - - 安装并启动 Redis,或调整 `REDIS_URL` 指向可用实例。 - -## **七、项目结构** - -- 核心配置:`hertz_server_django/settings.py`、`hertz_server_django/urls.py` -- 启动脚本:`start_server.py` -- 依赖清单:`requirements.txt` -- 静态资源:`static/`,媒体资源:`media/` -- 业务模块:`hertz_studio_django_*` 系列包(鉴权、日志、通知、监控、Wiki、AI、YOLO、代码生成等)。 - -## 八、**快速启动** - -- 安装依赖:`pip install -r requirements.txt -i https://hertz:hertz@hzpypi.hzsystems.cn/simple/` -- 启动服务:`python start_server.py --port 8000` \ No newline at end of file diff --git a/docs/开发规范.md b/docs/开发规范.md deleted file mode 100644 index 1ac6e3b..0000000 --- a/docs/开发规范.md +++ /dev/null @@ -1,158 +0,0 @@ -# 新功能开发规范 - -## 
一、命名规范 -- APP 名称:`hertz_studio_django_xxx`,全小写,用下划线分隔 -- Python 包与模块:全小写,短名称,避免缩写不清晰 -- URL 命名空间:`app_name = 'hertz_studio_django_xxx'` -- 数据库表名:`Meta.db_table = 'hertz_xxx_model'`,避免与其他库冲突 -- 迁移文件命名:描述性动词+对象,如 `0003_add_field_to_model` - -## 二、项目结构与约定 -- 标准结构:`apps.py`、`models.py`、`serializers.py`、`views.py`、`urls.py`、`admin.py` -- 静态资源:`media//...` 放置同路径资源以覆盖库静态文件 -- 配置集中:在 `settings.py` 维护,使用前缀化大写变量(如 `YOLO_MODEL`) - -## 三、接口返回规范 -- 统一使用 `HertzResponse`(路径:`hertz_studio_django_utils/responses/HertzResponse.py`) -- 成功: - ```python - from hertz_studio_django_utils.responses.HertzResponse import HertzResponse - return HertzResponse.success(data={'id': 1}, message='操作成功') - ``` -- 失败: - ```python - return HertzResponse.fail(message='业务失败') - ``` -- 错误: - ```python - return HertzResponse.error(message='系统错误', error=str(e)) - ``` -- 验证失败: - ```python - return HertzResponse.validation_error(message='参数验证失败', errors=serializer.errors) - ``` -- 统一键:`success | code | message | data`,禁止返回非标准顶层结构 - -## 四、API 设计规范 -- 路径语义化:`/models/`、`/detections/`、`/alerts/` -- 方法约定:`GET` 查询、`POST` 创建/动作、`PUT/PATCH` 更新、`DELETE` 删除 -- 分页:请求参数 `page, page_size`;响应 `total, items` -- 过滤与排序:查询参数 `q, order_by, filters`;谨慎开放可排序字段 - -## 五、认证与授权 -- 强制认证:业务敏感接口使用登录态(装饰器或 DRF 权限类) -- 权限控制:按用户/组/角色配置;避免在视图中硬编码权限 -- 速率限制:对登录、验证码、检测等接口进行限流 - -## 六、日志与审计 -- 请求审计:记录请求方法、路径、用户、响应码、耗时 -- 业务事件:模型启用/删除、检测结果、告警变更记录 -- 脱敏:对密码、令牌、隐私字段进行统一脱敏 - -## 七、配置约定 -- 所有库配置集中在 `settings.py`,使用前缀化变量: - - AI:`AI_MODEL_PROVIDER`、`AI_DEFAULT_MODEL`、`AI_TIMEOUT` - - Auth:`AUTH_LOGIN_REDIRECT_URL`、`AUTH_ENABLE_OAUTH` - - Captcha:`CAPTCHA_TYPE`、`CAPTCHA_EXPIRE_SECONDS` - - Log:`LOG_LEVEL`、`LOG_SINKS`、`LOG_REDACT_FIELDS` - - Notice:`NOTICE_CHANNELS`、`NOTICE_RETRY` - - Monitor:`MONITOR_PROBES`、`MONITOR_ALERTS` - - Wiki:`WIKI_MARKDOWN`、`WIKI_SEARCH_BACKEND` - - YOLO:`YOLO_MODEL`、`YOLO_DEVICE`、`YOLO_CONF_THRESHOLD` - - Codegen:`CODEGEN_TEMPLATES_DIR`、`CODEGEN_OUTPUT_DIR` - -## 八、可扩展性(不改库源码) -- 视图子类化 + 路由覆盖:在项目中继承库视图并替换路由匹配 - ```python 
- # urls.py - from django.urls import path, include - from your_app.views import MyDetectView - urlpatterns = [ - path('yolo/detect/', MyDetectView.as_view(), name='detect'), - path('yolo/', include('hertz_studio_django_yolo.urls', namespace='hertz_studio_django_yolo')), - ] - ``` -- 猴子补丁:在 `AppConfig.ready()` 将库函数替换为自定义函数 - ```python - # apps.py - from django.apps import AppConfig - class YourAppConfig(AppConfig): - name = 'your_app' - def ready(self): - from hertz_studio_django_yolo import views as yviews - from your_app.views import my_yolo_detection - yviews.yolo_detection = my_yolo_detection - ``` -- Admin 重注册:`unregister` 后 `register` 自定义 `ModelAdmin` -- 信号连接:在 `ready()` 中连接库暴露的信号以扩展行为 - -## 九、示例:覆写 YOLO 检测返回值 -- 目标位置:`hertz_studio_django_yolo/views.py:586-603` -- 最小替换示例(路由覆盖): - ```python - from rest_framework.decorators import api_view, parser_classes - from rest_framework.parsers import MultiPartParser, FormParser - from django.contrib.auth.decorators import login_required - from hertz_studio_django_utils.responses.HertzResponse import HertzResponse - from hertz_studio_django_yolo.views import _perform_detection - from hertz_studio_django_yolo.models import YoloModel, DetectionRecord - import uuid, os, time, shutil - - @api_view(['POST']) - @parser_classes([MultiPartParser, FormParser]) - @login_required - def my_yolo_detection(request): - # 复用库的流程,省略若干步骤,仅演示返回体差异化 - serializer = hertz_studio_django_yolo.serializers.DetectionRequestSerializer(data=request.data) - if not serializer.is_valid(): - return HertzResponse.validation_error(message='参数验证失败', errors=serializer.errors) - uploaded_file = serializer.validated_data['file'] - yolo_model = YoloModel.get_enabled_model() - original_path = '...' 
# 省略:存储原始文件 - result_path, object_count, detected_categories, confidence_scores, avg_confidence = _perform_detection('...', yolo_model.model_path, 0.5, 'image', yolo_model) - processing_time = time.time() - time.time() - detection_record = DetectionRecord.objects.create( - original_file=original_path, - result_file='...', - detection_type='image', - model_name=f"{yolo_model.name} {yolo_model.version}", - model=yolo_model, - user=request.user, - user_name=request.user.username, - object_count=object_count, - detected_categories=detected_categories, - confidence_threshold=0.5, - confidence_scores=confidence_scores, - avg_confidence=avg_confidence, - processing_time=processing_time - ) - return HertzResponse.success( - data={ - 'id': detection_record.id, - 'file': { - 'result_url': detection_record.result_file.url, - 'original_url': detection_record.original_file.url - }, - 'stats': { - 'count': object_count, - 'categories': detected_categories, - 'scores': confidence_scores, - 'avg_score': round(avg_confidence, 4) if avg_confidence is not None else None, - 'time': round(processing_time, 2) - }, - 'model': { - 'name': yolo_model.name, - 'version': yolo_model.version, - 'threshold': 0.5 - }, - 'user': { - 'id': getattr(request.user, 'user_id', None), - 'name': request.user.username - } - }, - message='检测完成' - ) - ``` - - - diff --git a/docs/项目简介.md b/docs/项目简介.md deleted file mode 100644 index ffb2f2c..0000000 --- a/docs/项目简介.md +++ /dev/null @@ -1,53 +0,0 @@ -# Hertz Studio 后端 - -## **一、系统简介** - -- 统一后端服务,提供 `REST API` 与 `WebSocket`,面向 AI 工作室与通用后台场景。 -- 模块化设计,覆盖认证与权限、通知公告、日志、知识库、系统监控、AI 对话、代码生成、Sklearn 推理、YOLO 目标检测等。 -- 基于 `ASGI` 架构,使用 `Daphne` 运行;默认使用 `SQLite`,可切换 `MySQL`;缓存与消息通道使用 `Redis`。 -- 自动化启动脚本 `start_server.py` 支持数据库迁移与初始数据(菜单、角色、超级管理员)初始化,以及热重启文件监听。 - -## 二、**体验账户** - -- 管理员 - -​ 账号:hertz 密码:hertz - -- 普通用户 - - 账号:demo 密码:123456 - -## 三、**技术栈** - -- 后端框架:`Django 5`、`Django REST Framework`、`Channels` + `Daphne`。 -- 数据与缓存:`SQLite`(默认)/ 
`MySQL`(可选)、`Redis`(缓存、会话、Channel Layer)。 -- API 文档:`drf-spectacular` 自动生成,提供 Swagger 与 Redoc 页面。 -- 认证与安全:自定义 `AuthMiddleware` + `JWT`(`pyjwt`),`CORS` 支持。 -- AI / ML:`Ultralytics YOLO`、`OpenCV`、`NumPy`、`Scikit-learn`、`Joblib`,以及本地 `Ollama` 对话集成。 -- 工具与其他:`Mako` 模板(代码生成)、`Pillow`、`watchdog`、`psutil`、`GPUtil`。 - -## **四、功能** - -- 认证与权限(`hertz_studio_django_auth`) - - 用户注册/登录/登出、密码管理、用户信息维护。 - - `JWT` 发放与刷新,角色/菜单权限体系,接口权限由 `AuthMiddleware` 统一控制。 -- 图形验证码(`hertz_studio_django_captcha`) - - 可配置验证码生成与校验、尺寸/颜色/噪声等参数支持。 -- 通知公告(`hertz_studio_django_notice`) - - 公告 CRUD、状态管理,面向工作室信息发布。 -- 日志管理(`hertz_studio_django_log`) - - 操作日志采集与查询,支持接口级日志记录装饰器。 -- 知识库(`hertz_studio_django_wiki`) - - 文章与分类管理,面向知识内容沉淀与检索。 -- 系统监控(`hertz_studio_django_system_monitor`) - - CPU/内存/磁盘/GPU 指标采集与展示,基于 `psutil`/`GPUtil`。 -- AI 对话(`hertz_studio_django_ai`) - - 对接本地 `Ollama`,提供对话接口与 `WebSocket` 推送能力。 -- 代码生成(`hertz_studio_django_codegen`) - - 基于 `Mako` 的 Django 代码与菜单生成器,支持 CLI 生成并同步权限。 -- Sklearn/PyTorch 推理(`hertz_studio_django_sklearn`) - - 模型上传、元数据解析(特征/输入输出模式)、统一预测接口,支持 `predict_proba`。 -- YOLO 目标检测(`hertz_studio_django_yolo`) - - 模型管理与启用切换、检测接口、结果中文别名标注与图像绘制;默认模型位于 `static/models/yolov12/weights/best.pt`。 -- Demo 与首页(`hertz_demo`) - - 示例页面(验证码、邮件、WebSocket)与首页模板 `templates/index.html`。 \ No newline at end of file diff --git a/generate_menu.py b/generate_menu.py deleted file mode 100644 index a55e725..0000000 --- a/generate_menu.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python -""" -菜单生成器命令行工具 -用于快速生成菜单配置和权限同步 -""" - -import os -import sys -import argparse -import django -from pathlib import Path - -# 添加项目路径 -project_root = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, project_root) - -# 设置Django环境 -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings') -django.setup() - -from hertz_studio_django_utils.code_generator.menu_generator import MenuGenerator - - -def generate_crud_menu(args): - """生成CRUD菜单""" - generator = MenuGenerator() - - 
operations = args.operations.split(',') if args.operations else ['list', 'create', 'update', 'delete'] - - menus = generator.generate_menu_config( - module_name=args.module_name, - model_name=args.model_name, - operations=operations, - parent_code=args.parent_code, - menu_prefix=args.prefix, - sort_order=args.sort_order, - icon=args.icon - ) - - # 保存到待同步文件 - pending_file = os.path.join(project_root, 'pending_menus.py') - with open(pending_file, 'w', encoding='utf-8') as f: - f.write('# 待同步的菜单配置\n') - f.write('pending_menus = [\n') - for menu in menus: - f.write(' {\n') - for key, value in menu.items(): - if isinstance(value, str): - f.write(f" '{key}': '{value}',\n") - elif value is None: - f.write(f" '{key}': None,\n") - else: - f.write(f" '{key}': {value},\n") - f.write(' },\n') - f.write(']\n') - - print(f"已生成 {len(menus)} 个菜单配置,保存到 pending_menus.py") - print("请重启服务器以同步菜单到数据库") - - -def menu_generator_main(): - parser = argparse.ArgumentParser(description='菜单生成器') - subparsers = parser.add_subparsers(dest='command', help='可用命令') - - # CRUD菜单生成命令 - crud_parser = subparsers.add_parser('crud', help='生成CRUD菜单') - crud_parser.add_argument('module_name', help='模块名称(中文)') - crud_parser.add_argument('model_name', help='模型名称(英文)') - crud_parser.add_argument('--parent-code', default='system', help='父级菜单代码') - crud_parser.add_argument('--prefix', default='system', help='菜单前缀') - crud_parser.add_argument('--operations', help='操作列表(逗号分隔)') - crud_parser.add_argument('--sort-order', type=int, default=1, help='排序') - crud_parser.add_argument('--icon', help='图标') - - args = parser.parse_args() - - if args.command == 'crud': - generate_crud_menu(args) - else: - parser.print_help() - - -if __name__ == "__main__": - menu_generator_main() diff --git a/get_machine_code.bat b/get_machine_code.bat deleted file mode 100644 index 93fd3eb..0000000 --- a/get_machine_code.bat +++ /dev/null @@ -1,13 +0,0 @@ -@echo off -chcp 65001 >nul -echo ================================ -echo Hertz 
Django Get Machine Code -echo ================================ - -python get_machine_code.py - -echo. -echo ================================ -echo Please send the machine code to the after-sales personnel -echo ================================ -pause >nul \ No newline at end of file diff --git a/get_machine_code.py b/get_machine_code.py deleted file mode 100644 index b4b2d42..0000000 --- a/get_machine_code.py +++ /dev/null @@ -1,18 +0,0 @@ -import platform -import uuid -import hashlib - - -def get_machine_id() -> str: - """生成机器码。 - - 根据当前系统信息(平台、架构、MAC地址)生成唯一机器码, - 使用 SHA256 取前16位并转大写,前缀为 HERTZ_STUDIO_。 - """ - system_info = f"{platform.platform()}-{platform.machine()}-{uuid.getnode()}" - return 'HERTZ_STUDIO_' + hashlib.sha256(system_info.encode()).hexdigest()[:16].upper() - - -if __name__ == "__main__": - machine_code = get_machine_id() - print(f"您的机器码是: {machine_code}") \ No newline at end of file diff --git a/get_tokens.py b/get_tokens.py deleted file mode 100644 index fd0d87a..0000000 --- a/get_tokens.py +++ /dev/null @@ -1,63 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -获取用户JWT token的脚本 -""" -import os -import sys -import django - -# 设置Django环境 -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings') -django.setup() - -from hertz_studio_django_auth.models import HertzUser -from hertz_studio_django_auth.utils.auth.token_utils import TokenUtils - -def get_user_tokens(): - """获取用户token""" - try: - # 获取普通用户hertz - user = HertzUser.objects.get(username='demo') - user_roles = user.roles.all() - print(f'找到用户: {user.username}, 角色: {[role.role_code for role in user_roles]}') - - # 生成token - user_data = { - 'user_id': str(user.user_id), - 'username': user.username, - 'email': user.email, - 'roles': [role.role_code for role in user_roles], - 'permissions': [] - } - token_data = TokenUtils.generate_token(user_data) - print(f'普通用户token: {token_data}') - print() - - # 获取管理员用户hertz - admin_user = 
HertzUser.objects.get(username='hertz') - admin_roles = admin_user.roles.all() - print(f'找到管理员: {admin_user.username}, 角色: {[role.role_code for role in admin_roles]}') - - # 生成管理员token - admin_data = { - 'user_id': str(admin_user.user_id), - 'username': admin_user.username, - 'email': admin_user.email, - 'roles': [role.role_code for role in admin_roles], - 'permissions': [] - } - admin_token_data = TokenUtils.generate_token(admin_data) - print(f'管理员token: {admin_token_data}') - - return { - 'user_token': token_data, - 'admin_token': admin_token_data - } - - except Exception as e: - print(f'错误: {e}') - return None - -if __name__ == '__main__': - get_user_tokens() \ No newline at end of file diff --git a/hertz.txt b/hertz.txt deleted file mode 100644 index af74862..0000000 --- a/hertz.txt +++ /dev/null @@ -1,11 +0,0 @@ -# ================hertz官方库================ -hertz-studio-django-ai -hertz-studio-django-auth -hertz-studio-django-captcha -hertz-studio-django-kb -hertz-studio-django-log -hertz-studio-django-notice -hertz-studio-django-system-monitor -hertz-studio-django-wiki -hertz-studio-django-yolo -hertz-studio-django-yolo-train \ No newline at end of file diff --git a/hertz_demo/README.md b/hertz_demo/README.md deleted file mode 100644 index ce9cee0..0000000 --- a/hertz_demo/README.md +++ /dev/null @@ -1,274 +0,0 @@ -# Hertz Demo 演示模块 - -## 📋 模块概述 - -Hertz Demo 模块是一个功能演示和测试模块,提供了完整的示例代码和交互式演示页面,帮助开发者快速了解和使用 Hertz Server Django 框架的各项功能特性。 - -## ✨ 功能特性 - -- **验证码演示**: 展示多种验证码类型的生成、刷新和验证功能 -- **邮件系统演示**: 提供邮件模板预览和发送测试功能 -- **WebSocket演示**: 实时通信功能演示和测试 -- **交互式界面**: 美观的Web界面,支持实时操作和反馈 -- **完整示例代码**: 提供可直接参考的实现代码 - -## 📁 模块结构 - -``` -hertz_demo/ -├── __init__.py # 模块初始化 -├── apps.py # Django应用配置 -├── models.py # 数据模型(预留) -├── views.py # 视图函数和业务逻辑 -├── urls.py # URL路由配置 -├── tests.py # 单元测试 -├── consumers.py # WebSocket消费者 -├── routing.py # WebSocket路由 -└── templates/ # 模板文件 - ├── captcha_demo.html # 验证码演示页面 - ├── email_demo.html # 邮件系统演示页面 - └── websocket_demo.html # 
WebSocket演示页面 -``` - -## 🎯 核心功能详解 - -### 1. 验证码演示功能 - -验证码演示页面提供三种验证码类型: -- **随机字符验证码**: 随机生成的字母数字组合 -- **数学运算验证码**: 简单的数学计算验证 -- **单词验证码**: 英文单词验证 - -**主要功能**: -- 验证码实时生成和刷新 -- 前端Ajax验证 -- 后端表单验证 -- 验证码类型切换 - -### 2. 邮件系统演示功能 - -邮件演示页面提供多种邮件模板: -- **欢迎邮件**: 用户注册欢迎邮件模板 -- **系统通知**: 系统消息通知模板 -- **邮箱验证**: 邮箱验证邮件模板 -- **自定义邮件**: 支持自定义主题和内容 - -**主要功能**: -- 邮件模板实时预览 -- 邮件发送测试 -- 收件人邮箱验证 -- 发送状态反馈 - -### 3. WebSocket演示功能 - -WebSocket演示页面提供实时通信功能: -- **连接状态管理**: 显示WebSocket连接状态 -- **消息发送接收**: 实时消息通信 -- **广播功能**: 消息广播演示 -- **错误处理**: 连接异常处理 - -## 🚀 API接口 - -### 演示页面路由 - -| 路由 | 方法 | 描述 | -|------|------|------| -| `/demo/captcha/` | GET | 验证码演示页面 | -| `/demo/email/` | GET | 邮件系统演示页面 | -| `/demo/websocket/` | GET | WebSocket演示页面 | -| `/websocket/test/` | GET | WebSocket测试页面 | - -### Ajax接口 - -**验证码相关**: -- `POST /demo/captcha/` (Ajax): 验证码刷新和验证 -- 请求体: `{"action": "refresh/verify", "captcha_id": "...", "user_input": "..."}` - -**邮件发送**: -- `POST /demo/email/` (Ajax): 发送演示邮件 -- 请求体: 邮件类型、收件人邮箱、自定义内容等 - -## ⚙️ 配置参数 - -### 邮件配置(settings.py) -```python -# 邮件服务器配置 -EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend' -EMAIL_HOST = 'smtp.gmail.com' -EMAIL_PORT = 587 -EMAIL_USE_TLS = True -EMAIL_HOST_USER = 'your-email@gmail.com' -EMAIL_HOST_PASSWORD = 'your-app-password' -DEFAULT_FROM_EMAIL = 'noreply@yourdomain.com' -``` - -### WebSocket配置 -```python -# ASGI配置 -ASGI_APPLICATION = 'hertz_server_django.asgi.application' - -# Channel layers配置 -CHANNEL_LAYERS = { - 'default': { - 'BACKEND': 'channels_redis.core.RedisChannelLayer', - 'CONFIG': { - 'hosts': [('127.0.0.1', 6379)], - }, - }, -} -``` - -## 🛠️ 快速开始 - -### 1. 访问演示页面 - -启动开发服务器后,访问以下URL: - -```bash -# 验证码演示 -http://localhost:8000/demo/captcha/ - -# 邮件系统演示 -http://localhost:8000/demo/email/ - -# WebSocket演示 -http://localhost:8000/demo/websocket/ -``` - -### 2. 测试验证码功能 - -1. 打开验证码演示页面 -2. 选择验证码类型 -3. 点击验证码图片可刷新 -4. 输入验证码进行验证 -5. 观察验证结果反馈 - -### 3. 测试邮件功能 - -1. 打开邮件演示页面 -2. 选择邮件模板类型 -3. 输入收件人邮箱 -4. 点击发送测试邮件 -5. 
查看发送状态 - -### 4. 测试WebSocket功能 - -1. 打开WebSocket演示页面 -2. 点击"连接"按钮建立连接 -3. 在输入框中发送消息 -4. 观察消息接收和广播 -5. 测试断开重连功能 - -## 🔧 高级用法 - -### 自定义邮件模板 - -在 `views.py` 中的 `generate_email_content` 函数中添加新的邮件模板: - -```python -def generate_email_content(email_type, recipient_name, custom_subject='', custom_message=''): - email_templates = { - 'your_template': { - 'subject': '您的邮件主题', - 'html_template': ''' - - - - ''' - } - } - # ... -``` - -### 扩展验证码类型 - -在验证码演示中扩展新的验证码类型: - -```python -# 在 captcha_demo 函数中添加新的验证码类型 -captcha_types = { - 'random_char': '随机字符验证码', - 'math': '数学运算验证码', - 'word': '单词验证码', - 'new_type': '您的新验证码类型' # 新增类型 -} -``` - -### WebSocket消息处理 - -在 `consumers.py` 中扩展WebSocket消息处理逻辑: - -```python -class DemoConsumer(WebsocketConsumer): - async def receive(self, text_data): - data = json.loads(text_data) - message_type = data.get('type') - - if message_type == 'custom_message': - # 处理自定义消息类型 - await self.handle_custom_message(data) -``` - -## 🧪 测试 - -### 运行单元测试 - -```bash -python manage.py test hertz_demo -``` - -### 测试覆盖范围 - -- 验证码功能测试 -- 邮件发送测试 -- WebSocket连接测试 -- 页面渲染测试 - -## 🔒 安全考虑 - -### 验证码安全 -- 验证码有效期限制 -- 验证次数限制 -- 防止暴力破解 - -### 邮件安全 -- 收件人邮箱验证 -- 发送频率限制 -- 防止邮件滥用 - -### WebSocket安全 -- 连接认证 -- 消息内容过滤 -- 防止DDoS攻击 - -## ❓ 常见问题 - -### Q: 邮件发送失败怎么办? -A: 检查邮件服务器配置,确保SMTP设置正确,邮箱密码为应用专用密码。 - -### Q: WebSocket连接失败怎么办? -A: 检查Redis服务是否运行,确保CHANNEL_LAYERS配置正确。 - -### Q: 验证码验证总是失败? -A: 检查验证码存储后端(Redis)是否正常运行。 - -### Q: 如何添加新的演示功能? 
-A: 在views.py中添加新的视图函数,在urls.py中配置路由,在templates中添加模板文件。 - -## 📝 更新日志 - -### v1.0.0 (2024-01-01) -- 初始版本发布 -- 包含验证码、邮件、WebSocket演示功能 -- 提供完整的示例代码和文档 - -## 🔗 相关链接 - -- [🏠 返回主项目](../README.md) - Hertz Server Django 主项目文档 -- [🔐 认证授权模块](../hertz_studio_django_auth/README.md) - 用户管理和权限控制 -- [🛠️ 工具类模块](../hertz_studio_django_utils/README.md) - 加密、邮件和验证工具 -- [📋 代码风格指南](../docs/CODING_STYLE.md) - 开发规范和最佳实践 - ---- - -💡 **提示**: 此模块主要用于功能演示和学习参考,生产环境请根据实际需求进行适当调整和优化。 \ No newline at end of file diff --git a/hertz_demo/__init__.py b/hertz_demo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/hertz_demo/apps.py b/hertz_demo/apps.py deleted file mode 100644 index 3599ca4..0000000 --- a/hertz_demo/apps.py +++ /dev/null @@ -1,6 +0,0 @@ -from django.apps import AppConfig - - -class DemoConfig(AppConfig): - default_auto_field = 'django.db.models.BigAutoField' - name = 'hertz_demo' diff --git a/hertz_demo/consumers.py b/hertz_demo/consumers.py deleted file mode 100644 index 456d637..0000000 --- a/hertz_demo/consumers.py +++ /dev/null @@ -1,120 +0,0 @@ -import json -from datetime import datetime -from channels.generic.websocket import AsyncWebsocketConsumer -from channels.db import database_sync_to_async - - -class ChatConsumer(AsyncWebsocketConsumer): - async def connect(self): - self.room_name = self.scope['url_route']['kwargs']['room_name'] - self.room_group_name = f'chat_{self.room_name}' - - # Join room group - await self.channel_layer.group_add( - self.room_group_name, - self.channel_name - ) - - await self.accept() - - async def disconnect(self, close_code): - # Leave room group - await self.channel_layer.group_discard( - self.room_group_name, - self.channel_name - ) - - # Receive message from WebSocket - async def receive(self, text_data): - text_data_json = json.loads(text_data) - message_type = text_data_json.get('type', 'chat_message') - message = text_data_json.get('message', '') - username = text_data_json.get('username', 'Anonymous') - - # Send 
message to room group - await self.channel_layer.group_send( - self.room_group_name, - { - 'type': message_type, - 'message': message, - 'username': username, - 'timestamp': datetime.now().strftime('%H:%M:%S') - } - ) - - # Receive message from room group - async def chat_message(self, event): - message = event['message'] - username = event['username'] - timestamp = event['timestamp'] - - # Send message to WebSocket - await self.send(text_data=json.dumps({ - 'type': 'chat_message', - 'message': message, - 'username': username, - 'timestamp': timestamp - })) - - async def user_join(self, event): - username = event['username'] - timestamp = event['timestamp'] - - await self.send(text_data=json.dumps({ - 'type': 'user_notification', - 'message': f'{username} 加入了聊天室', - 'username': username, - 'timestamp': timestamp - })) - - async def user_leave(self, event): - username = event['username'] - timestamp = event['timestamp'] - - await self.send(text_data=json.dumps({ - 'type': 'user_notification', - 'message': f'{username} 离开了聊天室', - 'username': username, - 'timestamp': timestamp - })) - - -class EchoConsumer(AsyncWebsocketConsumer): - async def connect(self): - await self.accept() - - async def disconnect(self, close_code): - pass - - async def receive(self, text_data): - try: - # 解析接收到的JSON数据 - data = json.loads(text_data) - message = data.get('message', '').strip() - - if message: - # 返回回声消息 - response = { - 'type': 'echo_message', - 'original_message': message, - 'echo_message': f'回声: {message}', - 'timestamp': datetime.now().strftime('%H:%M:%S') - } - else: - # 如果消息为空 - response = { - 'type': 'error', - 'message': '消息不能为空', - 'timestamp': datetime.now().strftime('%H:%M:%S') - } - - await self.send(text_data=json.dumps(response, ensure_ascii=False)) - - except json.JSONDecodeError: - # JSON解析错误 - error_response = { - 'type': 'error', - 'message': '无效的JSON格式', - 'timestamp': datetime.now().strftime('%H:%M:%S') - } - await self.send(text_data=json.dumps(error_response, 
ensure_ascii=False)) \ No newline at end of file diff --git a/hertz_demo/migrations/__init__.py b/hertz_demo/migrations/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/hertz_demo/models.py b/hertz_demo/models.py deleted file mode 100644 index 71a8362..0000000 --- a/hertz_demo/models.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.db import models - -# Create your models here. diff --git a/hertz_demo/routing.py b/hertz_demo/routing.py deleted file mode 100644 index c1c2ded..0000000 --- a/hertz_demo/routing.py +++ /dev/null @@ -1,7 +0,0 @@ -from django.urls import re_path -from . import consumers - -websocket_urlpatterns = [ - re_path(r'ws/chat/(?P\w+)/$', consumers.ChatConsumer.as_asgi()), - re_path(r'ws/echo/$', consumers.EchoConsumer.as_asgi()), -] \ No newline at end of file diff --git a/hertz_demo/templates/captcha_demo.html b/hertz_demo/templates/captcha_demo.html deleted file mode 100644 index ac22ad0..0000000 --- a/hertz_demo/templates/captcha_demo.html +++ /dev/null @@ -1,499 +0,0 @@ - - - - - - Hertz验证码演示 - Hertz Server Django - - - -
-
-

🔐 Hertz验证码演示

-

{{ demo_description }}

-
- -
- -
-
-

🎯 Hertz验证码功能特性

-

自定义验证码系统,支持Redis缓存

-
-
-
-
    -
  • 🔤 随机字符验证码 - 生成随机字母数字组合
  • -
  • 🎨 自定义样式配置 - 支持颜色、字体、噪声等设置
  • -
  • ⚡ Ajax刷新功能 - 无需刷新页面
  • -
  • 💾 Redis缓存 - 高性能数据存储
  • -
  • ⏰ 超时自动失效 - 可配置过期时间
  • -
  • 🔧 灵活配置 - 通过settings.py进行配置
  • -
-
- -
-

配置信息

-

• 验证码长度: 可通过HERTZ_CAPTCHA_LENGTH配置

-

• 图片尺寸: 可通过HERTZ_CAPTCHA_WIDTH/HEIGHT配置

-

• 过期时间: 可通过HERTZ_CAPTCHA_TIMEOUT配置

-

• Redis前缀: 可通过HERTZ_CAPTCHA_REDIS_KEY_PREFIX配置

-
-
-
- - -
-
-

🔒 Hertz验证码测试

-

输入验证码进行功能测试

-
-
- {% if messages %} - {% for message in messages %} -
- {% if message.tags == 'success' %}✅{% elif message.tags == 'error' %}❌{% endif %} {{ message }} -
- {% endfor %} - {% endif %} - -
-

📋 Hertz验证码说明

-

• 随机字符验证码:生成随机字母和数字组合

-

• 特点:自定义样式,支持噪声干扰,Redis缓存存储

-

• 功能:支持Ajax刷新,自动过期失效

-
- -
- {% csrf_token %} - -
- -
- 验证码 - -
- - -
- - -
- -
-

💡 使用提示

-

• 点击验证码图片可以刷新

-

• 验证码不区分大小写

-

• 验证码有效期为5分钟

-
-
-
-
- - -
- - - - \ No newline at end of file diff --git a/hertz_demo/templates/email_demo.html b/hertz_demo/templates/email_demo.html deleted file mode 100644 index d25630d..0000000 --- a/hertz_demo/templates/email_demo.html +++ /dev/null @@ -1,520 +0,0 @@ - - - - - - 邮件系统演示 - Hertz Server Django - - - -
- - -
-

📧 邮件系统演示

-

体验Django邮件发送功能,支持多种邮件类型和模板

-
- -
-
-

📮 邮件发送功能

-

本演示展示了Django邮件系统的核心功能:

-
    -
  • 支持HTML和纯文本邮件
  • -
  • 多种邮件模板类型
  • -
  • SMTP配置和发送状态
  • -
  • 邮件预览和验证
  • -
- - -
- -
-

✉️ 发送邮件测试

- -
- -
- {% csrf_token %} - - -
- - -
- -
- - -
- - - -
- - -
- - - -
- -
-
-

正在发送邮件,请稍候...

-
-
-
- - -
- - - - \ No newline at end of file diff --git a/hertz_demo/templates/websocket_demo.html b/hertz_demo/templates/websocket_demo.html deleted file mode 100644 index c8c4638..0000000 --- a/hertz_demo/templates/websocket_demo.html +++ /dev/null @@ -1,556 +0,0 @@ - - - - - - WebSocket演示 - Hertz Server Django - - - -
-
-

🌐 WebSocket演示

-

实时通信功能展示 - 支持聊天室和回声测试

-
- -
- -
-
-

🔄 回声测试

-

发送消息,服务器会回声返回

-
-
-
未连接
-
-
- - -
-
- - - -
-
-
- - -
-
-

💬 聊天室

-

多用户实时聊天功能

-
-
-
未连接
-
- - - - -
-
-
- - - -
-
-
-
- - -
- - - - \ No newline at end of file diff --git a/hertz_demo/tests.py b/hertz_demo/tests.py deleted file mode 100644 index 7ce503c..0000000 --- a/hertz_demo/tests.py +++ /dev/null @@ -1,3 +0,0 @@ -from django.test import TestCase - -# Create your tests here. diff --git a/hertz_demo/urls.py b/hertz_demo/urls.py deleted file mode 100644 index 4d5147a..0000000 --- a/hertz_demo/urls.py +++ /dev/null @@ -1,11 +0,0 @@ -from django.urls import path -from . import views - -app_name = 'hertz_demo' - -urlpatterns = [ - path('demo/captcha/', views.captcha_demo, name='captcha_demo'), - path('demo/websocket/', views.websocket_demo, name='websocket_demo'), - path('websocket/test/', views.websocket_test, name='websocket_test'), - path('demo/email/', views.email_demo, name='email_demo'), -] \ No newline at end of file diff --git a/hertz_demo/views.py b/hertz_demo/views.py deleted file mode 100644 index 2af6d96..0000000 --- a/hertz_demo/views.py +++ /dev/null @@ -1,331 +0,0 @@ -from django.shortcuts import render, redirect -from django.http import JsonResponse, HttpResponse -from django.core.mail import send_mail, EmailMultiAlternatives -from django.template.loader import render_to_string -from django.utils.html import strip_tags -from django.views.decorators.csrf import csrf_exempt -from django.views.decorators.http import require_http_methods -from django.contrib import messages -from django import forms -from hertz_studio_django_captcha.captcha_generator import HertzCaptchaGenerator -import json -from django.conf import settings -import random -import string - -class HertzCaptchaForm(forms.Form): - """Hertz验证码表单""" - captcha_input = forms.CharField( - max_length=10, - widget=forms.TextInput(attrs={ - 'placeholder': '请输入验证码', - 'class': 'form-control', - 'autocomplete': 'off' - }), - label='验证码' - ) - captcha_id = forms.CharField(widget=forms.HiddenInput(), required=False) - -def captcha_demo(request): - """ - 验证码演示页面 - 展示多种验证码功能的使用方法 - """ - # 获取请求的验证码类型 - captcha_type = 
request.GET.get('type', 'random_char') - - # 初始化验证码生成器 - captcha_generator = HertzCaptchaGenerator() - - if request.method == 'POST': - # 检查是否是Ajax请求 - if request.headers.get('X-Requested-With') == 'XMLHttpRequest': - try: - data = json.loads(request.body) - action = data.get('action') - - if action == 'refresh': - # 刷新验证码 - captcha_data = captcha_generator.generate_captcha() - return JsonResponse({ - 'success': True, - 'data': captcha_data - }) - elif action == 'verify': - # 验证验证码 - captcha_id = data.get('captcha_id', '') - user_input = data.get('user_input', '') - - is_valid = captcha_generator.verify_captcha(captcha_id, user_input) - - if is_valid: - return JsonResponse({ - 'success': True, - 'valid': True, - 'message': f'验证成功!验证码类型: {captcha_type}' - }) - else: - return JsonResponse({ - 'success': True, - 'valid': False, - 'message': '验证码错误,请重新输入' - }) - except json.JSONDecodeError: - return JsonResponse({ - 'success': False, - 'error': '请求数据格式错误' - }) - else: - # 普通表单提交处理 - form = HertzCaptchaForm(request.POST) - username = request.POST.get('username', '') - captcha_id = request.POST.get('captcha_id', '') - captcha_input = request.POST.get('captcha_input', '') - - # 验证验证码 - is_valid = captcha_generator.verify_captcha(captcha_id, captcha_input) - - if is_valid and username: - # 生成新的验证码用于显示 - initial_captcha = captcha_generator.generate_captcha() - return render(request, 'captcha_demo.html', { - 'form': HertzCaptchaForm(), - 'success_message': f'验证成功!用户名: {username},验证码类型: {captcha_type}', - 'captcha_unavailable': False, - 'current_type': captcha_type, - 'initial_captcha': initial_captcha, - 'captcha_types': { - 'random_char': '随机字符验证码', - 'math': '数学运算验证码', - 'word': '单词验证码' - } - }) - - # GET请求或表单验证失败时,生成初始验证码 - form = HertzCaptchaForm() - initial_captcha = captcha_generator.generate_captcha() - - return render(request, 'captcha_demo.html', { - 'form': form, - 'captcha_unavailable': False, - 'current_type': captcha_type, - 'initial_captcha': initial_captcha, - 
'captcha_types': { - 'random_char': '随机字符验证码', - 'math': '数学运算验证码', - 'word': '单词验证码' - } - }) - -def websocket_demo(request): - """WebSocket演示页面""" - return render(request, 'websocket_demo.html') - -def websocket_test(request): - """ - WebSocket简单测试页面 - """ - return render(request, 'websocket_test.html') - -# 测试热重启功能 - 添加注释触发文件变化 - -def email_demo(request): - """邮件系统演示页面""" - if request.method == 'GET': - return render(request, 'email_demo.html') - - elif request.method == 'POST': - try: - # 获取表单数据 - email_type = request.POST.get('email_type', 'welcome') - recipient_email = request.POST.get('recipient_email') - recipient_name = request.POST.get('recipient_name', '用户') - custom_subject = request.POST.get('subject', '') - custom_message = request.POST.get('message', '') - - if not recipient_email: - return JsonResponse({ - 'success': False, - 'message': '请输入收件人邮箱地址' - }) - - # 根据邮件类型生成内容 - email_content = generate_email_content(email_type, recipient_name, custom_subject, custom_message) - - # 发送邮件 - success = send_demo_email( - recipient_email=recipient_email, - subject=email_content['subject'], - html_content=email_content['html_content'], - text_content=email_content['text_content'] - ) - - if success: - return JsonResponse({ - 'success': True, - 'message': f'邮件已成功发送到 {recipient_email}' - }) - else: - return JsonResponse({ - 'success': False, - 'message': '邮件发送失败,请检查邮件配置' - }) - - except Exception as e: - return JsonResponse({ - 'success': False, - 'message': f'发送失败:{str(e)}' - }) - -def generate_email_content(email_type, recipient_name, custom_subject='', custom_message=''): - """根据邮件类型生成邮件内容""" - - email_templates = { - 'welcome': { - 'subject': '🎉 欢迎加入 Hertz Server Django!', - 'html_template': f''' - - -
-
-

🎉 欢迎加入我们!

-
-
-

亲爱的 {recipient_name}

-

欢迎您注册成为我们的用户!我们很高兴您能加入我们的大家庭。

-

在这里,您可以享受到:

-
    -
  • 🔐 安全的验证码系统
  • -
  • 🌐 实时WebSocket通信
  • -
  • 📧 完善的邮件服务
  • -
  • 📚 详细的API文档
  • -
-

如果您有任何问题,请随时联系我们。

- -

祝您使用愉快!

-
-

此致
Hertz Server Django 团队

-
-
- - - ''' - }, - 'notification': { - 'subject': '🔔 系统通知 - Hertz Server Django', - 'html_template': f''' - - -
-
-

🔔 系统通知

-
-
-

亲爱的 {recipient_name}

-

您有一条新的系统通知:

-
-

您的账户设置已更新,如果这不是您的操作,请立即联系我们。

-
-

系统会持续为您提供安全保障,如有疑问请联系客服。

- -
-

此致
Hertz Server Django 团队

-
-
- - - ''' - }, - 'verification': { - 'subject': '🔐 邮箱验证 - Hertz Server Django', - 'html_template': f''' - - -
-
-

🔐 邮箱验证

-
-
-

亲爱的 {recipient_name}

-

感谢您注册 Hertz Server Django!请点击下面的按钮验证您的邮箱地址:

- -

如果按钮无法点击,请复制以下链接到浏览器:
- http://127.0.0.1:8000/verify?token=demo_token

-

如果您没有注册账户,请忽略此邮件。此验证链接将在24小时后失效。

-
-

此致
Hertz Server Django 团队

-
-
- - - ''' - }, - 'custom': { - 'subject': custom_subject or '自定义邮件 - Hertz Server Django', - 'html_template': f''' - - -
-
-

{custom_subject or '自定义邮件'}

-
-
-

亲爱的 {recipient_name}

-
- {custom_message.replace(chr(10), '
') if custom_message else '这是一封自定义邮件。'} -
-
-

此致
Hertz Server Django 团队

-
-
- - - ''' - } - } - - template = email_templates.get(email_type, email_templates['welcome']) - html_content = template['html_template'] - text_content = strip_tags(html_content) - - return { - 'subject': template['subject'], - 'html_content': html_content, - 'text_content': text_content - } - -def send_demo_email(recipient_email, subject, html_content, text_content): - """发送演示邮件""" - try: - # 检查邮件配置 - if not settings.EMAIL_HOST_USER or not settings.EMAIL_HOST_PASSWORD: - print("邮件配置不完整,使用控制台输出模式") - print(f"收件人: {recipient_email}") - print(f"主题: {subject}") - print(f"内容: {text_content[:200]}...") - return True - - # 创建邮件 - email = EmailMultiAlternatives( - subject=subject, - body=text_content, - from_email=settings.DEFAULT_FROM_EMAIL, - to=[recipient_email] - ) - - # 添加HTML内容 - email.attach_alternative(html_content, "text/html") - - # 发送邮件 - email.send() - return True - - except Exception as e: - print(f"邮件发送失败: {str(e)}") - return False diff --git a/hertz_server_django/__init__.py b/hertz_server_django/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/hertz_server_django/asgi.py b/hertz_server_django/asgi.py deleted file mode 100644 index e3dbe88..0000000 --- a/hertz_server_django/asgi.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -ASGI config for hertz_server_django project. - -It exposes the ASGI callable as a module-level variable named ``application``. - -For more information on this file, see -https://docs.djangoproject.com/en/5.2/howto/deployment/asgi/ -""" - -import os - -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings') - -# Import Django first to ensure proper initialization -from django.core.asgi import get_asgi_application - -# Initialize Django ASGI application early to ensure the AppRegistry -# is populated before importing code that may import ORM models. 
-django_asgi_app = get_asgi_application() - -# Import other modules AFTER Django setup -from django.conf import settings -from channels.routing import ProtocolTypeRouter, URLRouter -from channels.auth import AuthMiddlewareStack -from channels.security.websocket import AllowedHostsOriginValidator - -# Import websocket routing AFTER Django setup to avoid AppRegistryNotReady -from hertz_demo import routing as demo_routing - -if 'hertz_studio_django_yolo' in settings.INSTALLED_APPS: - from hertz_studio_django_yolo import routing as yolo_routing - websocket_urlpatterns = ( - demo_routing.websocket_urlpatterns + - yolo_routing.websocket_urlpatterns - ) -else: - websocket_urlpatterns = demo_routing.websocket_urlpatterns - -# 在开发环境下放宽Origin校验,便于第三方客户端(如 Apifox、wscat)调试 -websocket_app = AuthMiddlewareStack( - URLRouter( - websocket_urlpatterns - ) -) - -if getattr(settings, 'DEBUG', False): - application = ProtocolTypeRouter({ - "http": django_asgi_app, - "websocket": websocket_app, - }) -else: - application = ProtocolTypeRouter({ - "http": django_asgi_app, - "websocket": AllowedHostsOriginValidator(websocket_app), - }) - diff --git a/hertz_server_django/settings.py b/hertz_server_django/settings.py deleted file mode 100644 index 8a2bda5..0000000 --- a/hertz_server_django/settings.py +++ /dev/null @@ -1,364 +0,0 @@ -""" -Django settings for hertz_server_django project. - -Generated by 'django-admin startproject' using Django 5.2.6. 
- -For more information on this file, see -https://docs.djangoproject.com/en/5.2/topics/settings/ - -For the full list of settings and their values, see -https://docs.djangoproject.com/en/5.2/ref/settings/ -""" - -import os -from pathlib import Path -from decouple import config - -# 修复DRF的ip_address_validators函数 -def fix_drf_ip_validators(): - """ - 修复DRF的ip_address_validators函数返回值问题 - """ - try: - from rest_framework import fields - - # 保存原始函数 - original_ip_address_validators = fields.ip_address_validators - - def fixed_ip_address_validators(protocol, unpack_ipv4): - """ - 修复后的ip_address_validators函数,确保返回两个值 - """ - validators = original_ip_address_validators(protocol, unpack_ipv4) - # 如果只返回了validators,添加默认的error_message - if isinstance(validators, list): - return validators, 'Enter a valid IP address.' - else: - # 如果已经返回了两个值,直接返回 - return validators - - # 应用猴子补丁 - fields.ip_address_validators = fixed_ip_address_validators - - except ImportError: - # 如果DRF未安装,忽略错误 - pass - -# 应用修复 -fix_drf_ip_validators() - -# Build paths inside the project like this: BASE_DIR / 'subdir'. -BASE_DIR = Path(__file__).resolve().parent.parent - - -# Quick-start development settings - unsuitable for production -# See https://docs.djangoproject.com/en/5.2/howto/deployment/checklist/ - -# SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = config('SECRET_KEY', default='django-insecure-0a1bx*8!97l^4z#ml#ufn_*9ut*)zlso$*k-g^h&(2=p@^51md') - -# SECURITY WARNING: don't run with debug turned on in production! 
-DEBUG = config('DEBUG', default=True, cast=bool) - -ALLOWED_HOSTS = config('ALLOWED_HOSTS', default='localhost,127.0.0.1', cast=lambda v: [s.strip() for s in v.split(',')]) - -# Database engine configuration (sqlite/mysql) with backward compatibility -# Prefer `DB_ENGINE` env var; fallback to legacy `USE_REDIS_AS_DB` -DB_ENGINE = config('DB_ENGINE', default=None) -USE_REDIS_AS_DB = config('USE_REDIS_AS_DB', default=True, cast=bool) -if DB_ENGINE is None: - DB_ENGINE = 'sqlite' if USE_REDIS_AS_DB else 'mysql' - -# Application definition - -INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - - # Third party apps - 'rest_framework', - 'corsheaders', - 'channels', - 'drf_spectacular', - 'hertz_studio_django_codegen', # 自动注册的应用 - - # 必备注册的app,不要删 - 'hertz_demo', # 初始化演示模块 - 'hertz_studio_django_captcha', # 验证码模块 - 'hertz_studio_django_auth', # 权限模块 - 'hertz_studio_django_system_monitor', # 系统监测模块 - 'hertz_studio_django_log', # 日志管理模块 - 'hertz_studio_django_notice', # 通知模块 - - # ======在下面导入你需要的app====== - 'hertz_studio_django_ai', #ai聊天模块 - 'hertz_studio_django_kb', # 知识库 ai和kb库是相互绑定的 - 'hertz_studio_django_wiki', # 文章模块 - 'hertz_studio_django_yolo', # YOLO目标检测模块 - 'hertz_studio_django_yolo_train', # Yolo训练模块 - -] - -MIDDLEWARE = [ - 'corsheaders.middleware.CorsMiddleware', - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'hertz_studio_django_auth.utils.middleware.AuthMiddleware', # 权限认证中间件 - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', -] - -ROOT_URLCONF = 'hertz_server_django.urls' - -TEMPLATES = [ - { - 'BACKEND': 
'django.template.backends.django.DjangoTemplates', - 'DIRS': [BASE_DIR / 'templates'] - , - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', - ], - }, - }, -] - -WSGI_APPLICATION = 'hertz_server_django.wsgi.application' - - -# Database -# https://docs.djangoproject.com/en/5.2/ref/settings/#databases - -if DB_ENGINE == 'sqlite': - DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': BASE_DIR / 'data/db.sqlite3', - } - } - # Use Redis-backed sessions when on SQLite (optional, keeps prior behavior) - SESSION_ENGINE = 'django.contrib.sessions.backends.cache' - SESSION_CACHE_ALIAS = 'default' -elif DB_ENGINE == 'mysql': - DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.mysql', - 'NAME': config('DB_NAME', default='hertz_server'), - 'USER': config('DB_USER', default='root'), - 'PASSWORD': config('DB_PASSWORD', default='root'), - 'HOST': config('DB_HOST', default='localhost'), - 'PORT': config('DB_PORT', default='3306'), - 'OPTIONS': { - 'charset': 'utf8mb4', - }, - } - } -else: - # Fallback to SQLite for unexpected values - DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': BASE_DIR / 'data/db.sqlite3', - } - } - -# Redis -CACHES = { - 'default': { - 'BACKEND': 'django_redis.cache.RedisCache', - 'LOCATION': config('REDIS_URL', default='redis://127.0.0.1:6379/0'), - 'OPTIONS': { - 'CLIENT_CLASS': 'django_redis.client.DefaultClient', - } - } -} - - -# Password validation -# https://docs.djangoproject.com/en/5.2/ref/settings/#auth-password-validators - -AUTH_PASSWORD_VALIDATORS = [ - { - 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', - }, - { - 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', - }, - 
{ - 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', - }, -] - - -# Internationalization -# https://docs.djangoproject.com/en/5.2/topics/i18n/ - -LANGUAGE_CODE = 'en-us' - -TIME_ZONE = 'UTC' - -USE_I18N = True - -USE_TZ = True - - -# Static files (CSS, JavaScript, Images) -# https://docs.djangoproject.com/en/5.2/howto/static-files/ - -STATIC_URL = 'static/' -STATICFILES_DIRS = [ - BASE_DIR / 'static', -] - -# Media files (User uploaded files) -MEDIA_URL = '/media/' -MEDIA_ROOT = BASE_DIR / 'media' - -# Default primary key field type -# https://docs.djangoproject.com/en/5.2/ref/settings/#default-auto-field - -DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' - -# Django REST Framework configuration -# 使用自定义AuthMiddleware进行认证,不使用DRF的认证和权限系统 -REST_FRAMEWORK = { - 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', - 'DEFAULT_AUTHENTICATION_CLASSES': [], # 不使用DRF认证类 - 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.AllowAny', # 所有接口默认允许访问,由AuthMiddleware控制权限 - ], - 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', - 'PAGE_SIZE': 20, - 'DEFAULT_RENDERER_CLASSES': [ - 'rest_framework.renderers.JSONRenderer', - 'rest_framework.renderers.BrowsableAPIRenderer', - ], -} - -# Spectacular (OpenAPI 3.0) configuration -SPECTACULAR_SETTINGS = { - 'TITLE': 'Hertz Server API', - 'DESCRIPTION': 'API documentation for Hertz Server Django project', - 'VERSION': '1.0.0', - 'SERVE_INCLUDE_SCHEMA': False, - 'COMPONENT_SPLIT_REQUEST': True, - 'SCHEMA_PATH_PREFIX': '/api/', -} - -# CORS configuration -CORS_ALLOWED_ORIGINS = config( - 'CORS_ALLOWED_ORIGINS', - default='http://localhost:3000,http://127.0.0.1:3000', - cast=lambda v: [s.strip() for s in v.split(',')] -) - -CORS_ALLOW_CREDENTIALS = True - -CORS_ALLOW_ALL_ORIGINS = config('CORS_ALLOW_ALL_ORIGINS', default=False, cast=bool) - -# Captcha settings -CAPTCHA_IMAGE_SIZE = ( - config('CAPTCHA_IMAGE_SIZE_WIDTH', default=120, cast=int), - 
config('CAPTCHA_IMAGE_SIZE_HEIGHT', default=50, cast=int) -) -CAPTCHA_LENGTH = config('CAPTCHA_LENGTH', default=4, cast=int) -CAPTCHA_TIMEOUT = config('CAPTCHA_TIMEOUT', default=5, cast=int) # minutes -CAPTCHA_FONT_SIZE = config('CAPTCHA_FONT_SIZE', default=40, cast=int) -CAPTCHA_BACKGROUND_COLOR = config('CAPTCHA_BACKGROUND_COLOR', default='#ffffff') -CAPTCHA_FOREGROUND_COLOR = config('CAPTCHA_FOREGROUND_COLOR', default='#000000') -# 验证码词典文件路径 -CAPTCHA_WORDS_DICTIONARY = str(BASE_DIR / 'captcha_words.txt') -# 验证码挑战函数配置 -CAPTCHA_CHALLENGE_FUNCT = 'captcha.helpers.random_char_challenge' # 默认使用随机字符 -# 数学验证码配置 -CAPTCHA_MATH_CHALLENGE_OPERATOR = '+-*' -# 验证码噪声和过滤器 -CAPTCHA_NOISE_FUNCTIONS = ( - 'captcha.helpers.noise_arcs', - 'captcha.helpers.noise_dots', -) -CAPTCHA_FILTER_FUNCTIONS = ( - 'captcha.helpers.post_smooth', -) - -# Hertz Captcha settings (used by hertz_studio_django_captcha.captcha_generator.HertzCaptchaGenerator) -HERTZ_CAPTCHA_LENGTH = config('HERTZ_CAPTCHA_LENGTH', default=4, cast=int) -HERTZ_CAPTCHA_WIDTH = config('HERTZ_CAPTCHA_WIDTH', default=160, cast=int) -HERTZ_CAPTCHA_HEIGHT = config('HERTZ_CAPTCHA_HEIGHT', default=60, cast=int) -HERTZ_CAPTCHA_FONT_SIZE = config('HERTZ_CAPTCHA_FONT_SIZE', default=40, cast=int) -HERTZ_CAPTCHA_TIMEOUT = config('HERTZ_CAPTCHA_TIMEOUT', default=300, cast=int) -HERTZ_CAPTCHA_BACKGROUND_COLOR = config('HERTZ_CAPTCHA_BACKGROUND_COLOR', default='#ffffff') -HERTZ_CAPTCHA_FOREGROUND_COLOR = config('HERTZ_CAPTCHA_FOREGROUND_COLOR', default='#000000') -HERTZ_CAPTCHA_NOISE_LEVEL = config('HERTZ_CAPTCHA_NOISE_LEVEL', default=0.3, cast=float) -HERTZ_CAPTCHA_REDIS_KEY_PREFIX = config('HERTZ_CAPTCHA_REDIS_KEY_PREFIX', default='hertz_captcha:') -HERTZ_CAPTCHA_FONT_PATH = config('HERTZ_CAPTCHA_FONT_PATH', default=str(MEDIA_ROOT / 'arial.ttf')) - -# Email configuration -EMAIL_BACKEND = config('EMAIL_BACKEND', default='django.core.mail.backends.smtp.EmailBackend') -EMAIL_HOST = config('EMAIL_HOST', default='smtp.qq.com') -EMAIL_PORT 
= config('EMAIL_PORT', default=465, cast=int) -EMAIL_USE_SSL = config('EMAIL_USE_SSL', default=True, cast=bool) -EMAIL_USE_TLS = config('EMAIL_USE_TLS', default=False, cast=bool) -EMAIL_HOST_USER = config('EMAIL_HOST_USER', default='563161210@qq.com') -EMAIL_HOST_PASSWORD = config('EMAIL_HOST_PASSWORD', default='') -DEFAULT_FROM_EMAIL = config('DEFAULT_FROM_EMAIL', default='563161210@qq.com') - -# 注册邮箱验证码开关(0=关闭,1=开启) -REGISTER_EMAIL_VERIFICATION = config('REGISTER_EMAIL_VERIFICATION', default=0, cast=int) - -# Channels configuration for WebSocket support -ASGI_APPLICATION = 'hertz_server_django.asgi.application' - -# Channel layers configuration -CHANNEL_LAYERS = { - 'default': { - 'BACKEND': 'channels_redis.core.RedisChannelLayer', - 'CONFIG': { - "hosts": [config('REDIS_URL', default='redis://127.0.0.1:6379/2')], - }, - }, -} - -# 自定义用户模型 -AUTH_USER_MODEL = 'hertz_studio_django_auth.HertzUser' - -# JWT配置 -JWT_SECRET_KEY = config('JWT_SECRET_KEY', default=SECRET_KEY) -JWT_ALGORITHM = 'HS256' -JWT_ACCESS_TOKEN_LIFETIME = config('JWT_ACCESS_TOKEN_LIFETIME', default=60 * 60 * 24, cast=int) # 24小时 -JWT_REFRESH_TOKEN_LIFETIME = config('JWT_REFRESH_TOKEN_LIFETIME', default=60 * 60 * 24 * 7, cast=int) # 7天 - -# 权限系统配置 -HERTZ_AUTH_SETTINGS = { - 'SUPER_ADMIN_PERMISSIONS': ['*'], # 超级管理员拥有所有权限 - 'DEFAULT_PERMISSIONS': [], # 默认权限 -} - -# AuthMiddleware配置 - 不需要登录验证的URL模式(支持正则表达式) -NO_AUTH_PATTERNS = config( - 'NO_AUTH_PATTERNS', - default=r'^/api/auth/login/?$,^/api/auth/register/?$,^/api/auth/email/code/?$,^/api/auth/send-email-code/?$,^/api/auth/password/reset/?$,^/api/captcha/.*$,^/api/docs/.*$,^/api/redoc/.*$,^/api/schema/.*$,^/admin/.*$,^/static/.*$,^/media/.*$,^/demo/.*$,^/websocket/.*$,^/api/system/.*$', - cast=lambda v: [s.strip() for s in v.split(',')] -) - -# 密码加密配置 -PASSWORD_HASHERS = [ - 'hertz_studio_django_utils.crypto.MD5PasswordHasher', # 使用MD5加密 - 'django.contrib.auth.hashers.PBKDF2PasswordHasher', - 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', - 
'django.contrib.auth.hashers.Argon2PasswordHasher', - 'django.contrib.auth.hashers.BCryptSHA256PasswordHasher', -] diff --git a/hertz_server_django/urls.py b/hertz_server_django/urls.py deleted file mode 100644 index 846b6c5..0000000 --- a/hertz_server_django/urls.py +++ /dev/null @@ -1,73 +0,0 @@ -""" -URL configuration for hertz_server_django project. - -The `urlpatterns` list routes URLs to views. For more information please see: - https://docs.djangoproject.com/en/5.2/topics/http/urls/ -Examples: -Function views - 1. Add an import: from my_app import views - 2. Add a URL to urlpatterns: path('', views.home, name='home') -Class-based views - 1. Add an import: from other_app.views import Home - 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') -Including another URLconf - 1. Import the include() function: from django.urls import include, path - 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) -""" -from django.urls import path, include -from django.conf import settings -from django.conf.urls.static import static -from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView -from . 
import views - -urlpatterns = [ - - # API documentation routes - path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='swagger-ui'), - path('api/redoc/', SpectacularRedocView.as_view(url_name='schema'), name='redoc'), - path('api/schema/', SpectacularAPIView.as_view(), name='schema'), - - # 首页路由 - path('', views.index, name='index'), - - # Hertz Captcha routes - path('api/captcha/', include('hertz_studio_django_captcha.urls')), - - # Hertz Auth routes - path('api/', include('hertz_studio_django_auth.urls')), - - # Demo app routes - path('', include('hertz_demo.urls')), - - # Hertz System Monitor routes - path('api/system/', include('hertz_studio_django_system_monitor.urls')), - - # Hertz Log routes - path('api/log/', include('hertz_studio_django_log.urls')), - - # Hertz Notice routes - path('api/notice/', include('hertz_studio_django_notice.urls')), - - # ===========在下面添加你需要的路由=========== - # Hertz AI routes - path('api/ai/', include('hertz_studio_django_ai.urls')), - - # Hertz Knowledge Base routes - path('api/kb/', include('hertz_studio_django_kb.urls')), - - # Hertz Wiki routes - path('api/wiki/', include('hertz_studio_django_wiki.urls')), - - # Hertz YOLO routes - path('api/yolo/', include('hertz_studio_django_yolo.urls')), - - # YOLO 训练管理 - path('api/yolo/train/', include('hertz_studio_django_yolo_train.urls')), - - -] - -# 在开发环境下提供媒体文件服务 -if settings.DEBUG: - urlpatterns += static(settings.MEDIA_URL, document_root=settings.MEDIA_ROOT) - urlpatterns += static(settings.STATIC_URL, document_root=settings.STATICFILES_DIRS[0]) diff --git a/hertz_server_django/views.py b/hertz_server_django/views.py deleted file mode 100644 index 5d1f7e8..0000000 --- a/hertz_server_django/views.py +++ /dev/null @@ -1,13 +0,0 @@ -from django.contrib.auth.models import Permission -from django.shortcuts import render - -from hertz_studio_django_auth.utils.decorators import no_login_required - - -@no_login_required -def index(request): - """ - 系统首页视图 - 
展示系统的基础介绍和功能特性 - """ - return render(request, 'index.html') \ No newline at end of file diff --git a/hertz_server_django/wsgi.py b/hertz_server_django/wsgi.py deleted file mode 100644 index 1e9fdf9..0000000 --- a/hertz_server_django/wsgi.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -WSGI config for hertz_server_django project. - -It exposes the WSGI callable as a module-level variable named ``application``. - -For more information on this file, see -https://docs.djangoproject.com/en/5.2/howto/deployment/wsgi/ -""" - -import os - -from django.core.wsgi import get_wsgi_application - -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings') - -application = get_wsgi_application() diff --git a/hertz_server_django_ui/.editorconfig b/hertz_server_django_ui/.editorconfig deleted file mode 100644 index 5534254..0000000 --- a/hertz_server_django_ui/.editorconfig +++ /dev/null @@ -1,25 +0,0 @@ -# EditorConfig配置文件 -root = true - -[*] -charset = utf-8 -indent_style = space -indent_size = 2 -end_of_line = lf -insert_final_newline = true -trim_trailing_whitespace = true - -[*.md] -trim_trailing_whitespace = false - -[*.{yml,yaml}] -indent_size = 2 - -[*.{js,ts,vue}] -indent_size = 2 - -[*.json] -indent_size = 2 - -[*.{css,scss,sass}] -indent_size = 2 diff --git a/hertz_server_django_ui/.env b/hertz_server_django_ui/.env deleted file mode 100644 index 3c57600..0000000 --- a/hertz_server_django_ui/.env +++ /dev/null @@ -1,10 +0,0 @@ -# API 基础地址 -VITE_API_BASE_URL=http://localhost:8000 - -# 应用配置 -VITE_APP_TITLE=Hertz Admin -VITE_APP_VERSION=1.0.0 - -# 开发服务器配置 -VITE_DEV_SERVER_HOST=localhost -VITE_DEV_SERVER_PORT=3000 diff --git a/hertz_server_django_ui/.env.development b/hertz_server_django_ui/.env.development deleted file mode 100644 index 638a3b8..0000000 --- a/hertz_server_django_ui/.env.development +++ /dev/null @@ -1,2 +0,0 @@ -VITE_API_BASE_URL=http://localhost:8000 -VITE_TEMPLATE_SETUP_MODE=true \ No newline at end of file diff --git 
a/hertz_server_django_ui/.env.production b/hertz_server_django_ui/.env.production deleted file mode 100644 index 6178d6e..0000000 --- a/hertz_server_django_ui/.env.production +++ /dev/null @@ -1,2 +0,0 @@ -VITE_API_BASE_URL=http://localhost:8000 -VITE_TEMPLATE_SETUP_MODE=true diff --git a/hertz_server_django_ui/.gitignore b/hertz_server_django_ui/.gitignore deleted file mode 100644 index a547bf3..0000000 --- a/hertz_server_django_ui/.gitignore +++ /dev/null @@ -1,24 +0,0 @@ -# Logs -logs -*.log -npm-debug.log* -yarn-debug.log* -yarn-error.log* -pnpm-debug.log* -lerna-debug.log* - -node_modules -dist -dist-ssr -*.local - -# Editor directories and files -.vscode/* -!.vscode/extensions.json -.idea -.DS_Store -*.suo -*.ntvs* -*.njsproj -*.sln -*.sw? diff --git a/hertz_server_django_ui/README.md b/hertz_server_django_ui/README.md deleted file mode 100644 index d3c72fa..0000000 --- a/hertz_server_django_ui/README.md +++ /dev/null @@ -1,327 +0,0 @@ -
- -

通用大模型模板 · Hertz Admin + AI

- -现代化的管理后台前端模板,面向二次开发的前端工程师。内置账号体系、权限路由、主题美化、知识库、YOLO 模型全流程(管理 / 类别 / 告警 / 历史)等典型模块。 - -

-基于 Vite + Vue 3 + TypeScript + Ant Design Vue + Pinia + Vue Router 构建 -

- -
- ---- - -## ✨ 特性(面向前端) - -- **工程化完善**:TS 强类型、模块化 API、统一请求封装、权限化菜单/路由 -- **设计统一**:全局“超现代风格”主题,卡片 / 弹窗 / 按钮 / 输入 / 分页风格一致 -- **业务可复用**: - - 文章管理:分类树 + 列表搜索 + 编辑/发布 - - YOLO 模型:模型管理、模型类别管理、告警处理中心、检测历史管理 - - AI 助手:多会话列表 + 消息记录 + 多布局对话界面(含错误调试信息) - - 认证体系:登录/注册、验证码 -- **可扩展**:清晰的目录划分和命名规范,方便直接加模块或替换现有实现 - -## 🧩 技术栈 - -- 构建:Vite -- 语言:TypeScript -- 框架:Vue 3(Composition API) -- UI:Ant Design Vue -- 状态:Pinia -- 路由:Vue Router - -## 📦 项目结构与职责 - -> 根目录:`通用大模型模板/` - -```bash -通用大模型模板/ -└─ hertz_server_diango_ui_2/ # 前端工程(Vite) - ├─ public/ # 公共静态资源(不走打包器) - ├─ src/ - │ ├─ api/ # 接口定义(auth / yolo / knowledge / captcha / ai ...) - │ │ └─ yolo.ts # YOLO 模型 & 检测 & 类别相关 API - │ ├─ locales/ # 国际化文案 - │ ├─ router/ # 路由与菜单配置 - │ │ ├─ admin_menu.ts # 管理端菜单 + 路由映射(权限 key) - │ │ ├─ user_menu_ai.ts # 用户端菜单 + 路由映射(含 AI 助手) - │ │ └─ index.ts # Vue Router 实例 + 全局路由守卫 - │ ├─ stores/ # Pinia Store - │ │ ├─ hertz_app.ts # 全局应用设置(语言、布局、菜单折叠等) - │ │ ├─ hertz_user.ts # 用户 / 鉴权状态 - │ │ └─ hertz_theme.ts # 主题配置与 CSS 变量 - │ ├─ styles/ # 全局样式与变量 - │ │ ├─ index.scss # 全局组件风格覆盖(Button / Table / Modal ...) - │ │ └─ variables.scss # 主题色、阴影、圆角等变量 - │ ├─ utils/ # 工具方法 & 基础设施 - │ │ ├─ hertz_request.ts # Axios 封装(baseURL、拦截器、错误提示) - │ │ ├─ hertz_url.ts # 统一 URL 构造(API / 媒体 / WebSocket) - │ │ ├─ hertz_env.ts # 读取 & 校验 env 变量 - │ │ └─ hertz_router_utils.ts # 路由相关工具 & 调试 - │ ├─ views/ # 所有页面 - │ │ ├─ admin_page/ # 管理端页面 - │ │ │ ├─ ModelManagement.vue # YOLO 模型管理 - │ │ │ ├─ AlertLevelManagement.vue # 模型类别管理 - │ │ │ ├─ DetectionHistoryManagement.vue # 检测历史管理 - │ │ │ └─ ... # 其他管理端模块 - │ │ ├─ user_pages/ # 用户端页面(检测端 + AI 助手) - │ │ │ ├─ index.vue # 用户端主布局 + 顶部导航 - │ │ │ ├─ AiChat.vue # AI 助手对话页面 - │ │ │ ├─ YoloDetection.vue # 离线检测页面 - │ │ │ ├─ LiveDetection.vue # 实时检测页面(WebSocket) - │ │ │ └─ ... 
# 告警中心 / 通知中心 / 知识库等 - │ │ ├─ Login.vue # 登录页 - │ │ └─ register.vue # 注册页 - │ ├─ App.vue # 应用根组件 - │ └─ main.ts # 入口文件(挂载 Vue / 路由 / Pinia) - ├─ .env.development # 开发环境变量(前端专用) - ├─ .env.production # 生产构建环境变量 - ├─ vite.config.ts # Vite 配置(代理、构建、别名等) - └─ package.json -``` - -## 📁 文件与命名规范(建议) - -- **组件 / 页面** - - 页面:`src/views/admin_page/FooBarManagement.vue`,以业务 + Management 命名 - - 纯组件:放到 `src/components/`,使用大驼峰命名,如 `UserSelector.vue` -- **接口文件** - - 同一业务一个文件:`src/api/yolo.ts`、`src/api/auth.ts` - - 内部导出 `xxxApi` 对象 + TS 类型:`type AlertLevel`, `type YoloModel` 等 -- **样式** - - 全局或主题相关:放 `src/styles/`(注意不要在这里写页面私有样式) - - 单页面样式:使用 ` diff --git a/hertz_server_django_ui/src/api/ai.ts b/hertz_server_django_ui/src/api/ai.ts deleted file mode 100644 index a31ade1..0000000 --- a/hertz_server_django_ui/src/api/ai.ts +++ /dev/null @@ -1,96 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 通用响应类型 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 会话与消息类型 -export interface AIChatItem { - id: number - title: string - created_at: string - updated_at: string - latest_message?: string -} - -export interface AIChatDetail { - id: number - title: string - created_at: string - updated_at: string -} - -export interface AIChatMessage { - id: number - role: 'user' | 'assistant' | 'system' - content: string - created_at: string -} - -export interface ChatListData { - total: number - page: number - page_size: number - chats: AIChatItem[] -} - -export interface ChatDetailData { - chat: AIChatDetail - messages: AIChatMessage[] -} - -export interface SendMessageData { - user_message: AIChatMessage - ai_message: AIChatMessage -} - -// 将后端可能返回的 chat_id 统一规范为 id -const normalizeChatItem = (raw: any): AIChatItem => ({ - id: typeof raw?.id === 'number' ? 
raw.id : Number(raw?.chat_id), - title: raw?.title, - created_at: raw?.created_at, - updated_at: raw?.updated_at, - latest_message: raw?.latest_message, -}) - -const normalizeChatDetail = (raw: any): AIChatDetail => ({ - id: typeof raw?.id === 'number' ? raw.id : Number(raw?.chat_id), - title: raw?.title, - created_at: raw?.created_at, - updated_at: raw?.updated_at, -}) - -export const aiApi = { - listChats: (params?: { query?: string; page?: number; page_size?: number }): Promise> => - request.get('/api/ai/chats/', { params, showError: false }).then((resp: any) => { - if (resp?.data?.chats && Array.isArray(resp.data.chats)) { - resp.data.chats = resp.data.chats.map((c: any) => normalizeChatItem(c)) - } - return resp as ApiResponse - }), - - createChat: (body?: { title?: string }): Promise> => - request.post('/api/ai/chats/create/', body || { title: '新对话' }).then((resp: any) => { - if (resp?.data) resp.data = normalizeChatDetail(resp.data) - return resp as ApiResponse - }), - - getChatDetail: (chatId: number): Promise> => - request.get(`/api/ai/chats/${chatId}/`).then((resp: any) => { - if (resp?.data?.chat) resp.data.chat = normalizeChatDetail(resp.data.chat) - return resp as ApiResponse - }), - - updateChat: (chatId: number, body: { title: string }): Promise> => - request.put(`/api/ai/chats/${chatId}/update/`, body), - - deleteChats: (chatIds: number[]): Promise> => - request.post('/api/ai/chats/delete/', { chat_ids: chatIds }), - - sendMessage: (chatId: number, body: { content: string }): Promise> => - request.post(`/api/ai/chats/${chatId}/send/`, body), -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/auth.ts b/hertz_server_django_ui/src/api/auth.ts deleted file mode 100644 index 04e6702..0000000 --- a/hertz_server_django_ui/src/api/auth.ts +++ /dev/null @@ -1,47 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 注册接口数据类型 -export interface RegisterData { - username: string - password: string - confirm_password: string - 
email: string - phone: string - real_name: string - captcha: string - captcha_id: string -} - -// 发送邮箱验证码数据类型 -export interface SendEmailCodeData { - email: string - code_type: string -} - -// 登录接口数据类型 -export interface LoginData { - username: string - password: string - captcha_code: string - captcha_key: string -} - -// 注册API -export const registerUser = (data: RegisterData) => { - return request.post('/api/auth/register/', data) -} - -// 登录API -export const loginUser = (data: LoginData) => { - return request.post('/api/auth/login/', data) -} - -// 发送邮箱验证码API -export const sendEmailCode = (data: SendEmailCodeData) => { - return request.post('/api/auth/email/code/', data) -} - -// 登出API -export const logoutUser = () => { - return request.post('/api/auth/logout/') -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/captcha.ts b/hertz_server_django_ui/src/api/captcha.ts deleted file mode 100644 index be92632..0000000 --- a/hertz_server_django_ui/src/api/captcha.ts +++ /dev/null @@ -1,89 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 验证码相关接口类型定义 -export interface CaptchaResponse { - captcha_id: string - image_data: string // base64编码的图片 - expires_in: number // 过期时间(秒) -} - -export interface CaptchaRefreshResponse { - captcha_id: string - image_data: string // base64编码的图片 - expires_in: number // 过期时间(秒) -} - -/** - * 生成验证码 - */ -export const generateCaptcha = async (): Promise => { - console.log('🚀 开始发送验证码生成请求...') - console.log('📍 请求URL:', `${import.meta.env.VITE_API_BASE_URL}/api/captcha/generate/`) - console.log('🌐 环境变量 VITE_API_BASE_URL:', import.meta.env.VITE_API_BASE_URL) - - try { - const response = await request.post<{ - code: number - message: string - data: CaptchaResponse - }>('/api/captcha/generate/') - - console.log('✅ 验证码生成请求成功:', response) - return response.data - } catch (error: any) { - console.error('❌ 验证码生成请求失败 - 完整错误信息:') - console.error('错误对象:', error) - console.error('错误类型:', typeof error) - 
console.error('错误消息:', error?.message) - console.error('错误代码:', error?.code) - console.error('错误状态:', error?.status) - console.error('错误响应:', error?.response) - console.error('错误请求:', error?.request) - console.error('错误配置:', error?.config) - - // 检查是否是网络错误 - if (error?.code === 'NETWORK_ERROR' || error?.message?.includes('Network Error')) { - console.error('🌐 网络连接错误 - 可能的原因:') - console.error('1. 后端服务器未启动') - console.error('2. API地址不正确') - console.error('3. CORS配置问题') - console.error('4. 防火墙阻止连接') - } - - throw error - } -} - -/** - * 刷新验证码 - */ -export const refreshCaptcha = async (captcha_id: string): Promise => { - console.log('🔄 开始发送验证码刷新请求...') - console.log('📍 请求URL:', `${import.meta.env.VITE_API_BASE_URL}/api/captcha/refresh/`) - console.log('📦 请求数据:', { captcha_id }) - - try { - const response = await request.post<{ - code: number - message: string - data: CaptchaRefreshResponse - }>('/api/captcha/refresh/', { - captcha_id - }) - - console.log('✅ 验证码刷新请求成功:', response) - return response.data - } catch (error: any) { - console.error('❌ 验证码刷新请求失败 - 完整错误信息:') - console.error('错误对象:', error) - console.error('错误类型:', typeof error) - console.error('错误消息:', error?.message) - console.error('错误代码:', error?.code) - console.error('错误状态:', error?.status) - console.error('错误响应:', error?.response) - console.error('错误请求:', error?.request) - console.error('错误配置:', error?.config) - - throw error - } -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/dashboard.ts b/hertz_server_django_ui/src/api/dashboard.ts deleted file mode 100644 index 3254ae0..0000000 --- a/hertz_server_django_ui/src/api/dashboard.ts +++ /dev/null @@ -1,393 +0,0 @@ -import { request } from '@/utils/hertz_request' -import { logApi, type OperationLogListItem } from './log' -import { systemMonitorApi, type SystemInfo, type CpuInfo, type MemoryInfo, type DiskInfo } from './system_monitor' -import { noticeUserApi } from './notice_user' -import { knowledgeApi } from './knowledge' - -// 
仪表盘统计数据类型定义 -export interface DashboardStats { - totalUsers: number - totalNotifications: number - totalLogs: number - totalKnowledge: number - userGrowthRate: number - notificationGrowthRate: number - logGrowthRate: number - knowledgeGrowthRate: number -} - -// 最近活动数据类型 -export interface RecentActivity { - id: number - action: string - time: string - user: string - type: 'login' | 'create' | 'update' | 'system' | 'register' -} - -// 系统状态数据类型 -export interface SystemStatus { - cpuUsage: number - memoryUsage: number - diskUsage: number - networkStatus: 'normal' | 'warning' | 'error' -} - -// 访问趋势数据类型 -export interface VisitTrend { - date: string - visits: number - users: number -} - -// 仪表盘数据汇总类型 -export interface DashboardData { - stats: DashboardStats - recentActivities: RecentActivity[] - systemStatus: SystemStatus - visitTrends: VisitTrend[] -} - -// API响应类型 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 仪表盘API接口 -export const dashboardApi = { - // 获取仪表盘统计数据 - getStats: (): Promise> => { - return request.get('/api/dashboard/stats/') - }, - - // 获取真实统计数据 - getRealStats: async (): Promise> => { - try { - // 并行获取各种统计数据 - const [notificationStats, logStats, knowledgeStats] = await Promise.all([ - noticeUserApi.statistics().catch(() => ({ success: false, data: { total_count: 0, unread_count: 0 } })), - logApi.getList({ page: 1, page_size: 1 }).catch(() => ({ success: false, data: { count: 0 } })), - knowledgeApi.getArticles({ page: 1, page_size: 1 }).catch(() => ({ success: false, data: { total: 0 } })) - ]) - - // 计算统计数据 - const totalNotifications = notificationStats.success ? 
(notificationStats.data.total_count || 0) : 0 - - // 处理日志数据 - 兼容多种返回结构 - let totalLogs = 0 - if (logStats.success && logStats.data) { - const logData = logStats.data as any - console.log('日志API响应数据:', logData) - // 兼容DRF标准结构:{ count, next, previous, results } - if ('count' in logData) { - totalLogs = Number(logData.count) || 0 - } else if ('total' in logData) { - totalLogs = Number(logData.total) || 0 - } else if ('total_count' in logData) { - totalLogs = Number(logData.total_count) || 0 - } else if (logData.pagination && logData.pagination.total_count) { - totalLogs = Number(logData.pagination.total_count) || 0 - } - console.log('解析出的日志总数:', totalLogs) - } else { - console.log('日志API调用失败:', logStats) - } - - const totalKnowledge = knowledgeStats.success ? (knowledgeStats.data.total || 0) : 0 - - console.log('统计数据汇总:', { totalNotifications, totalLogs, totalKnowledge }) - - // 模拟增长率(实际项目中应该从后端获取) - const stats: DashboardStats = { - totalUsers: 0, // 暂时设为0,需要用户管理API - totalNotifications, - totalLogs, - totalKnowledge, - userGrowthRate: 0, - notificationGrowthRate: Math.floor(Math.random() * 20) - 10, // 模拟 -10% 到 +10% - logGrowthRate: Math.floor(Math.random() * 30) - 15, // 模拟 -15% 到 +15% - knowledgeGrowthRate: Math.floor(Math.random() * 25) - 12 // 模拟 -12% 到 +13% - } - - return { - success: true, - code: 200, - message: 'success', - data: stats - } - } catch (error) { - console.error('获取真实统计数据失败:', error) - return { - success: false, - code: 500, - message: '获取统计数据失败', - data: { - totalUsers: 0, - totalNotifications: 0, - totalLogs: 0, - totalKnowledge: 0, - userGrowthRate: 0, - notificationGrowthRate: 0, - logGrowthRate: 0, - knowledgeGrowthRate: 0 - } - } - } - }, - - // 获取最近活动(从日志接口) - getRecentActivities: async (limit: number = 10): Promise> => { - try { - const response = await logApi.getList({ page: 1, page_size: limit }) - if (response.success && response.data) { - // 根据实际API响应结构,数据可能在data.logs或data.results中 - const logs = (response.data as any).logs || 
(response.data as any).results || [] - const activities: RecentActivity[] = logs.map((log: any) => ({ - id: log.log_id || log.id, - action: log.description || log.operation_description || `${log.action_type_display || log.operation_type} - ${log.module || log.operation_module}`, - time: formatTimeAgo(log.created_at), - user: log.username || log.user?.username || '未知用户', - type: mapLogTypeToActivityType(log.action_type || log.operation_type) - })) - return { - success: true, - code: 200, - message: 'success', - data: activities - } - } - return { - success: false, - code: 500, - message: '获取活动数据失败', - data: [] - } - } catch (error) { - console.error('获取最近活动失败:', error) - return { - success: false, - code: 500, - message: '获取活动数据失败', - data: [] - } - } - }, - - // 获取系统状态(从系统监控接口) - getSystemStatus: async (): Promise> => { - try { - const [cpuResponse, memoryResponse, disksResponse] = await Promise.all([ - systemMonitorApi.getCpu(), - systemMonitorApi.getMemory(), - systemMonitorApi.getDisks() - ]) - - if (cpuResponse.success && memoryResponse.success && disksResponse.success) { - // 根据实际API响应结构映射数据 - const systemStatus: SystemStatus = { - // CPU使用率:从 cpu_percent 字段获取 - cpuUsage: Math.round(cpuResponse.data.cpu_percent || 0), - // 内存使用率:从 percent 字段获取 - memoryUsage: Math.round(memoryResponse.data.percent || 0), - // 磁盘使用率:从磁盘数组的第一个磁盘的 percent 字段获取 - diskUsage: disksResponse.data.length > 0 ? 
Math.round(disksResponse.data[0].percent || 0) : 0, - networkStatus: 'normal' as const - } - - return { - success: true, - code: 200, - message: 'success', - data: systemStatus - } - } - - return { - success: false, - code: 500, - message: '获取系统状态失败', - data: { - cpuUsage: 0, - memoryUsage: 0, - diskUsage: 0, - networkStatus: 'error' as const - } - } - } catch (error) { - console.error('获取系统状态失败:', error) - return { - success: false, - code: 500, - message: '获取系统状态失败', - data: { - cpuUsage: 0, - memoryUsage: 0, - diskUsage: 0, - networkStatus: 'error' as const - } - } - } - }, - - // 获取访问趋势 - getVisitTrends: (period: 'week' | 'month' | 'year' = 'week'): Promise> => { - return request.get('/api/dashboard/visit-trends/', { params: { period } }) - }, - - // 获取完整仪表盘数据 - getDashboardData: (): Promise> => { - return request.get('/api/dashboard/overview/') - }, - - // 模拟数据方法(用于开发阶段) - getMockStats: (): Promise => { - return new Promise((resolve) => { - setTimeout(() => { - resolve({ - totalUsers: 1128, - todayVisits: 893, - totalOrders: 234, - totalRevenue: 12560.50, - userGrowthRate: 12, - visitGrowthRate: 8, - orderGrowthRate: -3, - revenueGrowthRate: 15 - }) - }, 500) - }) - }, - - getMockActivities: (): Promise => { - return new Promise((resolve) => { - setTimeout(() => { - resolve([ - { - id: 1, - action: '用户 张三 登录了系统', - time: '2分钟前', - user: '张三', - type: 'login' - }, - { - id: 2, - action: '管理员 李四 创建了新部门', - time: '5分钟前', - user: '李四', - type: 'create' - }, - { - id: 3, - action: '用户 王五 修改了个人信息', - time: '10分钟前', - user: '王五', - type: 'update' - }, - { - id: 4, - action: '系统自动备份完成', - time: '1小时前', - user: '系统', - type: 'system' - }, - { - id: 5, - action: '新用户 赵六 注册成功', - time: '2小时前', - user: '赵六', - type: 'register' - } - ]) - }, 300) - }) - }, - - getMockSystemStatus: (): Promise => { - return new Promise((resolve) => { - setTimeout(() => { - resolve({ - cpuUsage: 45, - memoryUsage: 67, - diskUsage: 32, - networkStatus: 'normal' - }) - }, 200) - }) - }, - - 
getMockVisitTrends: (period: 'week' | 'month' | 'year' = 'week'): Promise => { - return new Promise((resolve) => { - setTimeout(() => { - const data = { - week: [ - { date: '周一', visits: 120, users: 80 }, - { date: '周二', visits: 150, users: 95 }, - { date: '周三', visits: 180, users: 110 }, - { date: '周四', visits: 200, users: 130 }, - { date: '周五', visits: 250, users: 160 }, - { date: '周六', visits: 180, users: 120 }, - { date: '周日', visits: 160, users: 100 } - ], - month: [ - { date: '第1周', visits: 800, users: 500 }, - { date: '第2周', visits: 950, users: 600 }, - { date: '第3周', visits: 1100, users: 700 }, - { date: '第4周', visits: 1200, users: 750 } - ], - year: [ - { date: '1月', visits: 3200, users: 2000 }, - { date: '2月', visits: 3800, users: 2400 }, - { date: '3月', visits: 4200, users: 2600 }, - { date: '4月', visits: 3900, users: 2300 }, - { date: '5月', visits: 4500, users: 2800 }, - { date: '6月', visits: 5000, users: 3100 } - ] - } - resolve(data[period]) - }, 400) - }) - } -} - -// 辅助函数:格式化时间为相对时间 -function formatTimeAgo(dateString: string): string { - const now = new Date() - const date = new Date(dateString) - const diffInSeconds = Math.floor((now.getTime() - date.getTime()) / 1000) - - if (diffInSeconds < 60) { - return `${diffInSeconds}秒前` - } else if (diffInSeconds < 3600) { - const minutes = Math.floor(diffInSeconds / 60) - return `${minutes}分钟前` - } else if (diffInSeconds < 86400) { - const hours = Math.floor(diffInSeconds / 3600) - return `${hours}小时前` - } else { - const days = Math.floor(diffInSeconds / 86400) - return `${days}天前` - } -} - -// 辅助函数:将日志操作类型映射为活动类型 -function mapLogTypeToActivityType(operationType: string): RecentActivity['type'] { - if (!operationType) return 'system' - - const lowerType = operationType.toLowerCase() - - if (lowerType.includes('login') || lowerType.includes('登录')) { - return 'login' - } else if (lowerType.includes('create') || lowerType.includes('创建') || lowerType.includes('add') || lowerType.includes('新增')) { - return 
'create' - } else if (lowerType.includes('update') || lowerType.includes('修改') || lowerType.includes('edit') || lowerType.includes('更新')) { - return 'update' - } else if (lowerType.includes('register') || lowerType.includes('注册')) { - return 'register' - } else if (lowerType.includes('view') || lowerType.includes('查看') || lowerType.includes('get') || lowerType.includes('获取')) { - return 'system' - } else { - return 'system' - } -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/department.ts b/hertz_server_django_ui/src/api/department.ts deleted file mode 100644 index 36e850b..0000000 --- a/hertz_server_django_ui/src/api/department.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 部门数据类型定义 -export interface Department { - dept_id: number - parent_id: number | null - dept_name: string - dept_code: string - leader: string - phone: string | null - email: string | null - status: number - sort_order: number - created_at: string - updated_at: string - children?: Department[] - user_count?: number -} - -// API响应类型 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 部门列表数据类型 -export interface DepartmentListData { - list: Department[] - total: number - page: number - page_size: number -} - -export type DepartmentListResponse = ApiResponse - -// 部门列表查询参数 -export interface DepartmentListParams { - page?: number - page_size?: number - search?: string - status?: number - parent_id?: number -} - -// 创建部门参数 -export interface CreateDepartmentParams { - parent_id: null - dept_name: string - dept_code: string - leader: string - phone: string - email: string - status: number - sort_order: number -} - -// 更新部门参数 -export type UpdateDepartmentParams = Partial - -// 部门API接口 -export const departmentApi = { - // 获取部门列表 - getDepartmentList: (params?: DepartmentListParams): Promise> => { - return request.get('/api/departments/', { params }) - }, - - // 获取部门详情 - getDepartment: 
(id: number): Promise> => { - return request.get(`/api/departments/${id}/`) - }, - - // 创建部门 - createDepartment: (data: CreateDepartmentParams): Promise> => { - return request.post('/api/departments/create/', data) - }, - - // 更新部门 - updateDepartment: (id: number, data: UpdateDepartmentParams): Promise> => { - return request.put(`/api/departments/${id}/update/`, data) - }, - - // 删除部门 - deleteDepartment: (id: number): Promise> => { - return request.delete(`/api/departments/${id}/delete/`) - }, - - // 获取部门树 - getDepartmentTree: (): Promise> => { - return request.get('/api/departments/tree/') - } -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/index.ts b/hertz_server_django_ui/src/api/index.ts deleted file mode 100644 index 7909944..0000000 --- a/hertz_server_django_ui/src/api/index.ts +++ /dev/null @@ -1,17 +0,0 @@ -// API 统一出口文件 -export * from './captcha' -export * from './auth' -export * from './user' -export * from './department' -export * from './menu' -export * from './role' -export * from './password' -export * from './system_monitor' -export * from './dashboard' - -export * from './ai' -// 这里可以继续添加其它 API 模块的导出,例如: -// export * from './admin' -export * from './log' -export * from './knowledge' -export * from './kb' diff --git a/hertz_server_django_ui/src/api/kb.ts b/hertz_server_django_ui/src/api/kb.ts deleted file mode 100644 index ee1b721..0000000 --- a/hertz_server_django_ui/src/api/kb.ts +++ /dev/null @@ -1,131 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 通用响应结构(与后端 HertzResponse 对齐) -export interface KbApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 知识库条目 -export interface KbItem { - id: number - title: string - modality: 'text' | 'code' | 'image' | 'audio' | 'video' | string - source_type: 'text' | 'file' | 'url' | string - chunk_count?: number - created_at?: string - updated_at?: string - created_chunk_count?: number - // 允许后端扩展字段 - [key: string]: any -} - -export 
interface KbItemListParams { - query?: string - page?: number - page_size?: number -} - -export interface KbItemListData { - total: number - page: number - page_size: number - list: KbItem[] -} - -// 语义搜索 -export interface KbSearchParams { - q: string - k?: number -} - -// 问答(RAG) -export interface KbQaPayload { - question: string - k?: number -} - -export interface KbQaData { - answer: string - [key: string]: any -} - -// 图谱查询参数(实体 / 关系) -export interface KbGraphListParams { - query?: string - page?: number - page_size?: number - // 关系检索可选参数 - source?: number - target?: number - relation_type?: string -} - -export const kbApi = { - // 知识库条目:列表 - listItems(params?: KbItemListParams): Promise> { - return request.get('/api/kb/items/list/', { params }) - }, - - // 语义搜索 - search(params: KbSearchParams): Promise> { - return request.get('/api/kb/search/', { params }) - }, - - // 问答(RAG) - qa(payload: KbQaPayload): Promise> { - return request.post('/api/kb/qa/', payload) - }, - - // 图谱:实体列表 - listEntities(params?: KbGraphListParams): Promise> { - return request.get('/api/kb/graph/entities/', { params }) - }, - - // 图谱:关系列表 - listRelations(params?: KbGraphListParams): Promise> { - return request.get('/api/kb/graph/relations/', { params }) - }, - - // 知识库条目:创建(JSON 文本) - createItemJson(payload: { title: string; modality?: string; source_type?: string; content?: string; metadata?: any }): Promise> { - return request.post('/api/kb/items/create/', payload) - }, - - // 知识库条目:创建(文件上传) - createItemFile(formData: FormData): Promise> { - return request.post('/api/kb/items/create/', formData) - }, - - // 图谱:创建实体 - createEntity(payload: { name: string; type: string; properties?: any }): Promise> { - return request.post('/api/kb/graph/entities/', payload) - }, - - // 图谱:更新实体 - updateEntity(id: number, payload: { name?: string; type?: string; properties?: any }): Promise> { - return request.put(`/api/kb/graph/entities/${id}/`, payload) - }, - - // 图谱:删除实体 - deleteEntity(id: number): 
Promise> { - return request.delete(`/api/kb/graph/entities/${id}/`) - }, - - // 图谱:创建关系 - createRelation(payload: { source: number; target: number; relation_type: string; properties?: any; source_chunk?: number }): Promise> { - return request.post('/api/kb/graph/relations/', payload) - }, - - // 图谱:删除关系 - deleteRelation(id: number): Promise> { - return request.delete(`/api/kb/graph/relations/${id}/`) - }, - - // 图谱:自动抽取实体与关系 - extractGraph(payload: { text?: string; item_id?: number }): Promise> { - return request.post('/api/kb/graph/extract/', payload) - }, -} diff --git a/hertz_server_django_ui/src/api/knowledge.ts b/hertz_server_django_ui/src/api/knowledge.ts deleted file mode 100644 index dbc5314..0000000 --- a/hertz_server_django_ui/src/api/knowledge.ts +++ /dev/null @@ -1,173 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 通用响应结构 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 分类类型 -export interface KnowledgeCategory { - id: number - name: string - description?: string - parent?: number | null - parent_name?: string | null - sort_order?: number - is_active?: boolean - created_at?: string - updated_at?: string - children_count?: number - articles_count?: number - full_path?: string - children?: KnowledgeCategory[] -} - -export interface CategoryListData { - list: KnowledgeCategory[] - total: number - page: number - page_size: number -} - -export interface CategoryListParams { - page?: number - page_size?: number - name?: string - parent_id?: number - is_active?: boolean -} - -// 文章类型 -export interface KnowledgeArticleListItem { - id: number - title: string - summary?: string | null - image?: string | null - category_name: string - author_name: string - status: 'draft' | 'published' | 'archived' - status_display: string - view_count?: number - created_at: string - updated_at: string - published_at?: string | null -} - -export interface KnowledgeArticleDetail extends KnowledgeArticleListItem { 
- content: string - category: number - author: number - tags?: string - tags_list?: string[] - sort_order?: number -} - -export interface ArticleListData { - list: KnowledgeArticleListItem[] - total: number - page: number - page_size: number -} - -export interface ArticleListParams { - page?: number - page_size?: number - title?: string - category_id?: number - author_id?: number - status?: 'draft' | 'published' | 'archived' - tags?: string -} - -export interface CreateArticlePayload { - title: string - content: string - summary?: string - image?: string - category: number - status?: 'draft' | 'published' - tags?: string - sort_order?: number -} - -export interface UpdateArticlePayload { - title?: string - content?: string - summary?: string - image?: string - category?: number - status?: 'draft' | 'published' | 'archived' - tags?: string - sort_order?: number -} - -// 知识库 API -export const knowledgeApi = { - // 分类:列表 - getCategories: (params?: CategoryListParams): Promise> => { - return request.get('/api/wiki/categories/', { params }) - }, - - // 分类:树形 - getCategoryTree: (): Promise> => { - return request.get('/api/wiki/categories/tree/') - }, - - // 分类:详情 - getCategory: (id: number): Promise> => { - return request.get(`/api/wiki/categories/${id}/`) - }, - - // 分类:创建 - createCategory: (data: Partial): Promise> => { - return request.post('/api/wiki/categories/create/', data) - }, - - // 分类:更新 - updateCategory: (id: number, data: Partial): Promise> => { - return request.put(`/api/wiki/categories/${id}/update/`, data) - }, - - // 分类:删除 - deleteCategory: (id: number): Promise> => { - return request.delete(`/api/wiki/categories/${id}/delete/`) - }, - - // 文章:列表 - getArticles: (params?: ArticleListParams): Promise> => { - return request.get('/api/wiki/articles/', { params }) - }, - - // 文章:详情 - getArticle: (id: number): Promise> => { - return request.get(`/api/wiki/articles/${id}/`) - }, - - // 文章:创建 - createArticle: (data: CreateArticlePayload): Promise> => { - return 
request.post('/api/wiki/articles/create/', data) - }, - - // 文章:更新 - updateArticle: (id: number, data: UpdateArticlePayload): Promise> => { - return request.put(`/api/wiki/articles/${id}/update/`, data) - }, - - // 文章:删除 - deleteArticle: (id: number): Promise> => { - return request.delete(`/api/wiki/articles/${id}/delete/`) - }, - - // 文章:发布 - publishArticle: (id: number): Promise> => { - return request.post(`/api/wiki/articles/${id}/publish/`) - }, - - // 文章:归档 - archiveArticle: (id: number): Promise> => { - return request.post(`/api/wiki/articles/${id}/archive/`) - }, -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/log.ts b/hertz_server_django_ui/src/api/log.ts deleted file mode 100644 index 1b3342b..0000000 --- a/hertz_server_django_ui/src/api/log.ts +++ /dev/null @@ -1,110 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 通用 API 响应结构 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 列表查询参数 -export interface LogListParams { - page?: number - page_size?: number - user_id?: number - operation_type?: string - operation_module?: string - start_date?: string // YYYY-MM-DD - end_date?: string // YYYY-MM-DD - ip_address?: string - status?: number - // 新增:按请求方法与路径、关键字筛选(与后端保持可选兼容) - request_method?: 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' | string - request_path?: string - keyword?: string -} - -// 列表项(精简字段) -export interface OperationLogItem { - id: number - user?: { - id: number - username: string - email?: string - } | null - operation_type: string - // 展示字段 - action_type_display?: string - operation_module: string - operation_description?: string - target_model?: string - target_object_id?: string - ip_address?: string - request_method: string - request_path: string - response_status: number - // 结果与状态展示 - status_display?: string - is_success?: boolean - execution_time?: number - created_at: string -} - -// 列表响应 data 结构 -export interface LogListData { - count: number - 
next: string | null - previous: string | null - results: OperationLogItem[] -} - -export type LogListResponse = ApiResponse - -// 详情数据(完整字段) -export interface OperationLogDetail { - id: number - user?: { - id: number - username: string - email?: string - } | null - operation_type: string - action_type_display?: string - operation_module: string - operation_description: string - target_model?: string - target_object_id?: string - ip_address?: string - user_agent?: string - request_method: string - request_path: string - request_data?: Record - response_status: number - status_display?: string - is_success?: boolean - response_data?: Record - execution_time?: number - created_at: string - updated_at?: string -} - -export type LogDetailResponse = ApiResponse - -export const logApi = { - // 获取操作日志列表 - getList: (params: LogListParams, options?: { signal?: AbortSignal }): Promise => { - // 关闭统一错误弹窗,由页面自行处理 - return request.get('/api/log/list/', { params, showError: false, signal: options?.signal }) - }, - - // 获取操作日志详情 - getDetail: (logId: number): Promise => { - return request.get(`/api/log/detail/${logId}/`) - }, - - // 兼容查询参数方式的详情(部分后端实现为 /api/log/detail/?id=xx 或 ?log_id=xx) - getDetailByQuery: (logId: number): Promise => { - return request.get('/api/log/detail/', { params: { id: logId, log_id: logId } }) - }, -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/menu.ts b/hertz_server_django_ui/src/api/menu.ts deleted file mode 100644 index 4a9fb43..0000000 --- a/hertz_server_django_ui/src/api/menu.ts +++ /dev/null @@ -1,361 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 后端返回的原始菜单数据格式 -export interface RawMenu { - menu_id: number - menu_name: string - menu_code: string - menu_type: number // 后端返回数字:1=菜单, 2=按钮, 3=接口 - parent_id?: number | null - path?: string - component?: string | null - icon?: string - permission?: string - sort_order?: number - description?: string - status?: number - is_external?: boolean - is_cache?: boolean - 
is_visible?: boolean - created_at?: string - updated_at?: string - children?: RawMenu[] -} - -// 前端使用的菜单接口类型定义 -export interface Menu { - menu_id: number - menu_name: string - menu_code: string - menu_type: number // 1=菜单, 2=按钮, 3=接口 - parent_id?: number - path?: string - component?: string - icon?: string - permission?: string - sort_order?: number - status?: number - is_external?: boolean - is_cache?: boolean - is_visible?: boolean - created_at?: string - updated_at?: string - children?: Menu[] -} - -// API响应基础结构 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 菜单列表数据结构 -export interface MenuListData { - list: Menu[] - total: number - page: number - page_size: number -} - -// 菜单列表响应类型 -export type MenuListResponse = ApiResponse - -// 菜单列表查询参数 -export interface MenuListParams { - page?: number - page_size?: number - search?: string - status?: number - menu_type?: string - parent_id?: number -} - -// 创建菜单参数 -export interface CreateMenuParams { - menu_name: string - menu_code: string - menu_type: number // 1=菜单, 2=按钮, 3=接口 - parent_id?: number - path?: string - component?: string - icon?: string - permission?: string - sort_order?: number - status?: number - is_external?: boolean - is_cache?: boolean - is_visible?: boolean -} - -// 更新菜单参数 -export type UpdateMenuParams = Partial - -// 菜单树响应类型 -export type MenuTreeResponse = ApiResponse - -// 数据转换工具函数 -const convertMenuType = (type: number): 'menu' | 'button' | 'api' => { - switch (type) { - case 1: return 'menu' - case 2: return 'button' - case 3: return 'api' - default: return 'menu' - } -} - -// 解码Unicode字符串 -const decodeUnicode = (str: string): string => { - try { - return str.replace(/\\u[\dA-F]{4}/gi, (match) => { - return String.fromCharCode(parseInt(match.replace(/\\u/g, ''), 16)) - }) - } catch (error) { - return str - } -} - -// 转换原始菜单数据为前端格式 -const transformRawMenu = (rawMenu: RawMenu): Menu => { - // 确保status字段被正确转换 - let statusValue: number - if 
(rawMenu.status === undefined || rawMenu.status === null) { - // 如果status缺失,默认为启用(1) - statusValue = 1 - } else { - // 如果有值,转换为数字 - if (typeof rawMenu.status === 'string') { - const parsed = parseInt(rawMenu.status, 10) - statusValue = isNaN(parsed) ? 1 : parsed - } else { - statusValue = Number(rawMenu.status) - // 如果转换失败,默认为启用 - if (isNaN(statusValue)) { - statusValue = 1 - } - } - } - - return { - menu_id: rawMenu.menu_id, - menu_name: decodeUnicode(rawMenu.menu_name), - menu_code: rawMenu.menu_code, - menu_type: rawMenu.menu_type, - parent_id: rawMenu.parent_id || undefined, - path: rawMenu.path, - component: rawMenu.component, - icon: rawMenu.icon, - permission: rawMenu.permission, - sort_order: rawMenu.sort_order, - status: statusValue, // 使用转换后的值 - is_external: rawMenu.is_external, - is_cache: rawMenu.is_cache, - is_visible: rawMenu.is_visible, - created_at: rawMenu.created_at, - updated_at: rawMenu.updated_at, - children: rawMenu.children ? rawMenu.children.map(transformRawMenu) : [] - } -} - -// 将菜单数据数组转换为列表格式 -const transformToMenuList = (rawMenus: RawMenu[]): MenuListData => { - const transformedMenus = rawMenus.map(transformRawMenu) - - // 递归收集所有菜单项 - const collectAllMenus = (menu: Menu): Menu[] => { - const result = [menu] - if (menu.children && menu.children.length > 0) { - menu.children.forEach(child => { - result.push(...collectAllMenus(child)) - }) - } - return result - } - - // 收集所有菜单项 - const allMenus: Menu[] = [] - transformedMenus.forEach(menu => { - allMenus.push(...collectAllMenus(menu)) - }) - - return { - list: allMenus, - total: allMenus.length, - page: 1, - page_size: allMenus.length - } -} - -// 构建菜单树结构 -const buildMenuTree = (rawMenus: RawMenu[]): Menu[] => { - const transformedMenus = rawMenus.map(transformRawMenu) - - // 创建菜单映射 - const menuMap = new Map() - transformedMenus.forEach(menu => { - menuMap.set(menu.menu_id, { ...menu, children: [] }) - }) - - // 构建树结构 - const rootMenus: Menu[] = [] - transformedMenus.forEach(menu => { - 
const menuItem = menuMap.get(menu.menu_id)! - - if (menu.parent_id && menuMap.has(menu.parent_id)) { - const parent = menuMap.get(menu.parent_id)! - if (!parent.children) parent.children = [] - parent.children.push(menuItem) - } else { - rootMenus.push(menuItem) - } - }) - - return rootMenus -} - -// 菜单API -export const menuApi = { - // 获取菜单列表 - getMenuList: async (params?: MenuListParams): Promise => { - try { - const response = await request.get>('/api/menus/', { params }) - - if (response.success && response.data && Array.isArray(response.data)) { - const menuListData = transformToMenuList(response.data) - return { - success: true, - code: response.code, - message: response.message, - data: menuListData - } - } - - return { - success: false, - code: response.code || 500, - message: response.message || '获取菜单数据失败', - data: { - list: [], - total: 0, - page: 1, - page_size: 10 - } - } - } catch (error) { - console.error('获取菜单列表失败:', error) - return { - success: false, - code: 500, - message: '网络请求失败', - data: { - list: [], - total: 0, - page: 1, - page_size: 10 - } - } - } - }, - - // 获取菜单树 - getMenuTree: async (): Promise => { - try { - const response = await request.get>('/api/menus/tree/') - - if (response.success && response.data && Array.isArray(response.data)) { - // 调试:检查原始数据中的status值 - if (response.data.length > 0) { - console.log('🔍 原始菜单数据status检查(前5条):', response.data.slice(0, 5).map((m: RawMenu) => ({ - menu_name: m.menu_name, - menu_id: m.menu_id, - status: m.status, - statusType: typeof m.status - }))) - } - - // 后端已经返回树形结构,直接转换数据格式即可 - const transformedData = response.data.map(transformRawMenu) - - // 调试:检查转换后的status值 - if (transformedData.length > 0) { - console.log('🔍 转换后菜单数据status检查(前5条):', transformedData.slice(0, 5).map((m: Menu) => ({ - menu_name: m.menu_name, - menu_id: m.menu_id, - status: m.status, - statusType: typeof m.status - }))) - } - - return { - success: true, - code: response.code, - message: response.message, - data: transformedData 
- } - } - - return { - success: false, - code: response.code || 500, - message: response.message || '获取菜单树失败', - data: [] - } - } catch (error) { - console.error('获取菜单树失败:', error) - return { - success: false, - code: 500, - message: '网络请求失败', - data: [] - } - } - }, - - // 获取单个菜单 - getMenu: async (id: number): Promise> => { - try { - const response = await request.get>(`/api/menus/${id}/`) - - if (response.success && response.data) { - const transformedMenu = transformRawMenu(response.data) - return { - success: true, - code: response.code, - message: response.message, - data: transformedMenu - } - } - - return response as ApiResponse - } catch (error) { - console.error('获取菜单详情失败:', error) - return { - success: false, - code: 500, - message: '网络请求失败', - data: {} as Menu - } - } - }, - - // 创建菜单 - createMenu: (data: CreateMenuParams): Promise> => { - return request.post('/api/menus/create/', data) - }, - - // 更新菜单 - updateMenu: (id: number, data: UpdateMenuParams): Promise> => { - return request.put(`/api/menus/${id}/update/`, data) - }, - - // 删除菜单 - deleteMenu: (id: number): Promise> => { - return request.delete(`/api/menus/${id}/delete/`) - }, - - // 批量删除菜单 - batchDeleteMenus: (ids: number[]): Promise> => { - return request.post('/api/menus/batch-delete/', { menu_ids: ids }) - } -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/notice_user.ts b/hertz_server_django_ui/src/api/notice_user.ts deleted file mode 100644 index b065b30..0000000 --- a/hertz_server_django_ui/src/api/notice_user.ts +++ /dev/null @@ -1,87 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 用户端通知模块 API 类型定义 -export interface UserNoticeListItem { - notice: number - title: string - notice_type_display: string - priority_display: string - is_top: boolean - publish_time: string - is_read: boolean - read_time: string | null - is_starred: boolean - starred_time: string | null - is_expired: boolean - created_at: string -} - -export interface UserNoticeListData { 
- notices: UserNoticeListItem[] - pagination: { - current_page: number - page_size: number - total_pages: number - total_count: number - has_next: boolean - has_previous: boolean - } - statistics: { - total_count: number - unread_count: number - starred_count: number - } -} - -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -export interface UserNoticeDetailData { - notice: number - title: string - content: string - notice_type_display: string - priority_display: string - attachment_url: string | null - publish_time: string - expire_time: string - is_top: boolean - is_expired: boolean - publisher_name: string | null - is_read: boolean - read_time: string - is_starred: boolean - starred_time: string | null - created_at: string - updated_at: string -} - -export const noticeUserApi = { - // 查看通知列表 - list: (params?: { page?: number; page_size?: number }): Promise> => - request.get('/api/notice/user/list/', { params }), - - // 查看通知详情 - detail: (notice_id: number | string): Promise> => - request.get(`/api/notice/user/detail/${notice_id}/`), - - // 标记通知已读 - markRead: (notice_id: number | string): Promise> => - request.post('/api/notice/user/mark-read/', { notice_id }), - - // 批量标记通知已读 - batchMarkRead: (notice_ids: Array): Promise> => - request.post('/api/notice/user/batch-mark-read/', { notice_ids }), - - // 用户获取通知统计 - statistics: (): Promise; priority_statistics?: Record }>> => - request.get('/api/notice/user/statistics/'), - - // 收藏/取消收藏通知 - toggleStar: (notice_id: number | string, is_starred: boolean): Promise> => - request.post('/api/notice/user/toggle-star/', { notice_id, is_starred }), -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/password.ts b/hertz_server_django_ui/src/api/password.ts deleted file mode 100644 index 1e6f08f..0000000 --- a/hertz_server_django_ui/src/api/password.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 修改密码接口参数 -export 
interface ChangePasswordParams { - old_password: string - new_password: string - confirm_password: string -} - -// 重置密码接口参数 -export interface ResetPasswordParams { - email: string - email_code: string - new_password: string - confirm_password: string -} - -// 修改密码 -export const changePassword = (params: ChangePasswordParams) => { - return request.post('/api/auth/password/change/', params) -} - -// 重置密码 -export const resetPassword = (params: ResetPasswordParams) => { - return request.post('/api/auth/password/reset/', params) -} - -// 发送重置密码邮箱验证码 -export const sendResetPasswordCode = (email: string) => { - return request.post('/api/auth/password/reset/code/', { email }) -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/role.ts b/hertz_server_django_ui/src/api/role.ts deleted file mode 100644 index 4075e00..0000000 --- a/hertz_server_django_ui/src/api/role.ts +++ /dev/null @@ -1,130 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 权限接口类型定义 -export interface Permission { - permission_id: number - permission_name: string - permission_code: string - permission_type: 'menu' | 'button' | 'api' - parent_id?: number - path?: string - icon?: string - sort_order?: number - description?: string - status?: number - children?: Permission[] -} - -// 角色接口类型定义 -export interface Role { - role_id: number - role_name: string - role_code: string - description?: string - status?: number - created_at?: string - updated_at?: string - permissions?: Permission[] -} - -// API响应基础结构 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 角色列表数据结构 -export interface RoleListData { - list: Role[] - total: number - page: number - page_size: number -} - -// 角色列表响应类型 -export type RoleListResponse = ApiResponse - -// 角色列表查询参数 -export interface RoleListParams { - page?: number - page_size?: number - search?: string - status?: number -} - -// 创建角色参数 -export interface CreateRoleParams { - role_name: string - role_code: 
string - description?: string - status?: number -} - -// 更新角色参数 -export type UpdateRoleParams = Partial - -// 角色权限分配参数 -export interface AssignRolePermissionsParams { - role_id: number - menu_ids: number[] - user_type?: number - department_id?: number -} - -// 权限列表响应类型 -export type PermissionListResponse = ApiResponse - -// 角色API -export const roleApi = { - // 获取角色列表 - getRoleList: (params?: RoleListParams): Promise => { - return request.get('/api/roles/', { params }) - }, - - // 获取单个角色 - getRole: (id: number): Promise> => { - return request.get(`/api/roles/${id}/`) - }, - - // 创建角色 - createRole: (data: CreateRoleParams): Promise> => { - return request.post('/api/roles/create/', data) - }, - - // 更新角色 - updateRole: (id: number, data: UpdateRoleParams): Promise> => { - return request.put(`/api/roles/${id}/update/`, data) - }, - - // 删除角色 - deleteRole: (id: number): Promise> => { - return request.delete(`/api/roles/${id}/delete/`) - }, - - // 批量删除角色 - batchDeleteRoles: (ids: number[]): Promise> => { - return request.post('/api/roles/batch-delete/', { role_ids: ids }) - }, - - // 获取角色权限 - getRolePermissions: (id: number): Promise> => { - return request.get(`/api/roles/${id}/menus/`) - }, - - // 分配角色权限 - assignRolePermissions: (data: AssignRolePermissionsParams): Promise> => { - return request.post(`/api/roles/assign-menus/`, data) - }, - - // 获取所有权限列表 - getPermissionList: (): Promise => { - return request.get('/api/menus/') - }, - - // 获取权限树 - getPermissionTree: (): Promise => { - return request.get('/api/menus/tree/') - } -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/system_monitor.ts b/hertz_server_django_ui/src/api/system_monitor.ts deleted file mode 100644 index 66e4560..0000000 --- a/hertz_server_django_ui/src/api/system_monitor.ts +++ /dev/null @@ -1,114 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 通用响应类型 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 1. 
系统信息 -export interface SystemInfo { - hostname: string - platform: string - architecture: string - boot_time: string - uptime: string -} - -// 2. CPU 信息 -export interface CpuInfo { - cpu_count: number - cpu_percent: number - cpu_freq: { - current: number - min: number - max: number - } - load_avg: number[] -} - -// 3. 内存信息 -export interface MemoryInfo { - total: number - available: number - used: number - percent: number - free: number -} - -// 4. 磁盘信息 -export interface DiskInfo { - device: string - mountpoint: string - fstype: string - total: number - used: number - free: number - percent: number -} - -// 5. 网络信息 -export interface NetworkInfo { - interface: string - bytes_sent: number - bytes_recv: number - packets_sent: number - packets_recv: number -} - -// 6. 进程信息 -export interface ProcessInfo { - pid: number - name: string - status: string - cpu_percent: number - memory_percent: number - memory_info: { - rss: number - vms: number - } - create_time: string - cmdline: string[] -} - -// 7. GPU 信息 -export interface GpuInfoItem { - id: number - name: string - load: number - memory_total: number - memory_used: number - memory_util: number - temperature: number -} - -export interface GpuInfoResponse { - gpu_available: boolean - gpu_info?: GpuInfoItem[] - message?: string - timestamp: string -} - -// 8. 
综合监测信息 -export interface MonitorData { - system: SystemInfo - cpu: CpuInfo - memory: MemoryInfo - disks: DiskInfo[] - network: NetworkInfo[] - processes: ProcessInfo[] - gpus: Array<{ gpu_available: boolean; message?: string; timestamp: string }> -} - -export const systemMonitorApi = { - getSystem: (): Promise> => request.get('/api/system/system/'), - getCpu: (): Promise> => request.get('/api/system/cpu/'), - getMemory: (): Promise> => request.get('/api/system/memory/'), - getDisks: (): Promise> => request.get('/api/system/disks/'), - getNetwork: (): Promise> => request.get('/api/system/network/'), - getProcesses: (): Promise> => request.get('/api/system/processes/'), - getGpu: (): Promise> => request.get('/api/system/gpu/'), - getMonitor: (): Promise> => request.get('/api/system/monitor/'), -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/user.ts b/hertz_server_django_ui/src/api/user.ts deleted file mode 100644 index 9ddf0f2..0000000 --- a/hertz_server_django_ui/src/api/user.ts +++ /dev/null @@ -1,121 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// 角色接口类型定义 -export interface Role { - role_id: number - role_name: string - role_code: string - role_ids?: string -} - -// 用户接口类型定义(匹配后端实际数据结构) -export interface User { - user_id: number - username: string - email: string - phone?: string - real_name?: string - avatar?: string - gender: number - birthday?: string - department_id?: number - status: number - last_login_time?: string - last_login_ip?: string - created_at: string - updated_at: string - roles: Role[] -} - -// API响应基础结构 -export interface ApiResponse { - success: boolean - code: number - message: string - data: T -} - -// 用户列表数据结构 -export interface UserListData { - list: User[] - total: number - page: number - page_size: number -} - -// 用户列表响应类型 -export type UserListResponse = ApiResponse - -// 用户列表查询参数 -export interface UserListParams { - page?: number - page_size?: number - search?: string - status?: number - role_ids?: 
string -} - -// 分配角色参数 -export interface AssignRolesParams { - user_id: number - role_ids: number[] // 角色ID数组 -} - -// 用户API -export const userApi = { - // 获取用户列表 - getUserList: (params?: UserListParams): Promise => { - return request.get('/api/users/', { params }) - }, - - // 获取单个用户 - getUser: (id: number): Promise> => { - return request.get(`/api/users/${id}/`) - }, - - // 创建用户 - createUser: (data: Partial): Promise> => { - return request.post('/api/users/create/', data) - }, - - // 更新用户 - updateUser: (id: number, data: Partial): Promise> => { - return request.put(`/api/users/${id}/update/`, data) - }, - - // 删除用户 - deleteUser: (id: number): Promise> => { - return request.delete(`/api/users/${id}/delete/`) - }, - - // 批量删除用户 - batchDeleteUsers: (ids: number[]): Promise> => { - return request.post('/api/admin/users/batch-delete/', { user_ids: ids }) - }, - - // 获取当前用户信息 - getUserInfo: (): Promise> => { - return request.get('/api/auth/user/info/') - }, - - // 更新当前用户信息 - updateUserInfo: (data: Partial): Promise> => { - return request.put('/api/auth/user/info/update/', data) - }, - - uploadAvatar: (file: File): Promise> => { - const formData = new FormData() - formData.append('avatar', file) - return request.upload('/api/auth/user/avatar/upload/', formData) - }, - - // 分配用户角色 - assignRoles: (data: AssignRolesParams): Promise> => { - return request.post('/api/users/assign-roles/', data) - }, - - // 获取所有角色列表 - getRoleList: (): Promise> => { - return request.get('/api/roles/') - } -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/api/yolo.ts b/hertz_server_django_ui/src/api/yolo.ts deleted file mode 100644 index d74139e..0000000 --- a/hertz_server_django_ui/src/api/yolo.ts +++ /dev/null @@ -1,643 +0,0 @@ -import { request } from '@/utils/hertz_request' - -// YOLO检测相关接口类型定义 -export interface YoloDetectionRequest { - image: File - model_id?: string - confidence_threshold?: number - nms_threshold?: number -} - -export interface DetectionBbox { - x: 
number - y: number - width: number - height: number -} - -export interface YoloDetection { - class_id: number - class_name: string - confidence: number - bbox: DetectionBbox -} - -export interface YoloDetectionResponse { - message: string - data?: { - detection_id: number - result_file_url: string - original_file_url: string - object_count: number - detected_categories: string[] - confidence_scores: number[] - avg_confidence: number | null - processing_time: number - model_used: string - confidence_threshold: number - user_id: number - user_name: string - alert_level?: 'low' | 'medium' | 'high' - } -} - -export interface YoloModel { - id: string - name: string - version: string - description: string - classes: string[] - is_active: boolean - is_enabled: boolean - model_file: string - model_folder_path: string - model_path: string - weights_folder_path: string - categories: { [key: string]: any } - created_at: string - updated_at: string -} - -export interface YoloModelListResponse { - success: boolean - message?: string - data?: { - models: YoloModel[] - total: number - } -} - -// 数据集管理相关类型 -export interface YoloDatasetSummary { - id: number - name: string - version?: string - root_folder_path: string - data_yaml_path: string - nc?: number - description?: string - created_at?: string -} - -export interface YoloDatasetDetail extends YoloDatasetSummary { - names?: string[] - train_images_count?: number - train_labels_count?: number - val_images_count?: number - val_labels_count?: number - test_images_count?: number - test_labels_count?: number -} - -export interface YoloDatasetSampleItem { - image: string - image_size?: number - label?: string - filename: string -} - -// YOLO 训练任务相关类型 -export type YoloTrainStatus = - | 'queued' - | 'running' - | 'canceling' - | 'completed' - | 'failed' - | 'canceled' - -export interface YoloTrainDatasetOption { - id: number - name: string - version?: string - yaml: string -} - -export interface YoloTrainVersionOption { - family: 'v8' 
| '11' | '12' - config_path: string - sizes: string[] -} - -export interface YoloTrainOptionsResponse { - success: boolean - code?: number - message?: string - data?: { - datasets: YoloTrainDatasetOption[] - versions: YoloTrainVersionOption[] - } -} - -export interface YoloTrainingJob { - id: number - dataset: number - dataset_name: string - model_family: 'v8' | '11' | '12' - model_size?: 'n' | 's' | 'm' | 'l' | 'x' - weight_path?: string - config_path?: string - status: YoloTrainStatus - logs_path?: string - runs_path?: string - best_model_path?: string - last_model_path?: string - progress: number - epochs: number - imgsz: number - batch: number - device: string - optimizer: 'SGD' | 'Adam' | 'AdamW' | 'RMSProp' - error_message?: string - created_at: string - started_at?: string | null - finished_at?: string | null -} - -export interface StartTrainingPayload { - dataset_id: number - model_family: 'v8' | '11' | '12' - model_size?: 'n' | 's' | 'm' | 'l' | 'x' - epochs?: number - imgsz?: number - batch?: number - device?: string - optimizer?: 'SGD' | 'Adam' | 'AdamW' | 'RMSProp' -} - -export interface YoloTrainLogsResponse { - success: boolean - code?: number - message?: string - data?: { - content: string - next_offset: number - finished: boolean - } -} - -// YOLO检测API -export const yoloApi = { - // 执行YOLO检测 - async detectImage(detectionRequest: YoloDetectionRequest): Promise { - console.log('🔍 构建检测请求:', detectionRequest) - console.log('📁 文件对象详情:', { - name: detectionRequest.image.name, - size: detectionRequest.image.size, - type: detectionRequest.image.type, - lastModified: detectionRequest.image.lastModified - }) - - const formData = new FormData() - formData.append('file', detectionRequest.image) - - if (detectionRequest.model_id) { - formData.append('model_id', detectionRequest.model_id) - } - if (detectionRequest.confidence_threshold) { - formData.append('confidence_threshold', detectionRequest.confidence_threshold.toString()) - } - if 
(detectionRequest.nms_threshold) { - formData.append('nms_threshold', detectionRequest.nms_threshold.toString()) - } - - // 调试FormData内容 - console.log('📤 FormData内容:') - for (const [key, value] of formData.entries()) { - if (value instanceof File) { - console.log(` ${key}: File(${value.name}, ${value.size} bytes, ${value.type})`) - } else { - console.log(` ${key}:`, value) - } - } - - return request.post('/api/yolo/detect/', formData, { - headers: { - 'Content-Type': 'multipart/form-data' - } - }) - }, - - // 获取当前启用的YOLO模型信息 - async getCurrentEnabledModel(): Promise<{ success: boolean; data?: YoloModel; message?: string }> { - // 关闭全局错误提示,由调用方(如 YOLO 检测页面)自行处理“未启用模型”等业务文案 - return request.get('/api/yolo/models/enabled/', { showError: false }) - }, - - // 获取模型详情 - async getModelInfo(modelId: string): Promise<{ success: boolean; data?: YoloModel; message?: string }> { - return request.get(`/api/yolo/models/${modelId}`) - }, - - // 批量检测 - async detectBatch(images: File[], modelId?: string): Promise { - const promises = images.map(image => - this.detectImage({ - image, - model_id: modelId, - confidence_threshold: 0.5, - nms_threshold: 0.4 - }) - ) - - return Promise.all(promises) - }, - - // 获取模型列表 - async getModels(): Promise<{ success: boolean; data?: YoloModel[]; message?: string }> { - return request.get('/api/yolo/models/') - }, - - // 上传模型 - async uploadModel(formData: FormData): Promise<{ success: boolean; message?: string }> { - // 使用专门的upload方法,它会自动处理Content-Type - return request.upload('/api/yolo/upload/', formData) - }, - - // 更新模型信息 - async updateModel(modelId: string, data: { name: string; version: string }): Promise<{ success: boolean; data?: YoloModel; message?: string }> { - return request.put(`/api/yolo/models/${modelId}/update/`, data) - }, - - // 删除模型 - async deleteModel(modelId: string): Promise<{ success: boolean; message?: string }> { - return request.delete(`/api/yolo/models/${modelId}/delete/`) - }, - - // 启用模型 - async enableModel(modelId: 
string): Promise<{ success: boolean; data?: YoloModel; message?: string }> { - return request.post(`/api/yolo/models/${modelId}/enable/`) - }, - - // 获取模型详情 - async getModelDetail(modelId: string): Promise<{ success: boolean; data?: YoloModel; message?: string }> { - return request.get(`/api/yolo/models/${modelId}/`) - }, - - // 获取检测历史记录列表 - async getDetectionHistory(params?: { - page?: number - page_size?: number - search?: string - start_date?: string - end_date?: string - model_id?: string - }): Promise<{ success: boolean; data?: DetectionHistoryRecord[]; message?: string }> { - return request.get('/api/yolo/detections/', { params }) - }, - - // 获取检测记录详情 - async getDetectionDetail(recordId: string): Promise<{ success: boolean; data?: DetectionHistoryRecord; message?: string }> { - return request.get(`/api/detections/${recordId}/`) - }, - - // 删除检测记录 - async deleteDetection(recordId: string): Promise<{ success: boolean; message?: string }> { - return request.delete(`/api/yolo/detections/${recordId}/delete/`) - }, - - // 批量删除检测记录 - async batchDeleteDetections(ids: number[]): Promise<{ success: boolean; message?: string }> { - return request.post('/api/yolo/detections/batch-delete/', { ids }) - }, - - // 获取检测统计 - async getDetectionStats(): Promise<{ success: boolean; data?: any; message?: string }> { - return request.get('/api/yolo/stats/') - }, - - // 数据集管理相关接口 - // 上传数据集 - async uploadDataset(formData: FormData): Promise<{ success: boolean; data?: YoloDatasetDetail; message?: string }> { - return request.upload('/api/yolo/datasets/upload/', formData) - }, - - // 获取数据集列表 - async getDatasets(): Promise<{ success: boolean; data?: YoloDatasetSummary[]; message?: string }> { - return request.get('/api/yolo/datasets/') - }, - - // 获取数据集详情 - async getDatasetDetail(datasetId: number): Promise<{ success: boolean; data?: YoloDatasetDetail; message?: string }> { - return request.get(`/api/yolo/datasets/${datasetId}/`) - }, - - // 删除数据集 - async deleteDataset(datasetId: 
number): Promise<{ success: boolean; message?: string }> { - return request.post(`/api/yolo/datasets/${datasetId}/delete/`) - }, - - // 获取数据集样本 - async getDatasetSamples( - datasetId: number, - params: { split?: 'train' | 'val' | 'test'; limit?: number; offset?: number } = {} - ): Promise<{ - success: boolean - data?: { items: YoloDatasetSampleItem[]; total: number } - message?: string - }> { - return request.get(`/api/yolo/datasets/${datasetId}/samples/`, { params }) - }, - - // YOLO 训练任务相关接口 - // 获取训练选项(可用数据集与模型版本) - async getTrainOptions(): Promise { - return request.get('/api/yolo/train/options/') - }, - - // 获取训练任务列表 - async getTrainJobs(): Promise<{ - success: boolean - code?: number - message?: string - data?: YoloTrainingJob[] - }> { - return request.get('/api/yolo/train/jobs/') - }, - - // 创建并启动训练任务 - async startTrainJob(payload: StartTrainingPayload): Promise<{ - success: boolean - code?: number - message?: string - data?: YoloTrainingJob - }> { - return request.post('/api/yolo/train/jobs/start/', payload) - }, - - // 获取训练任务详情 - async getTrainJobDetail(id: number): Promise<{ - success: boolean - code?: number - message?: string - data?: YoloTrainingJob - }> { - return request.get(`/api/yolo/train/jobs/${id}/`) - }, - - // 获取训练任务日志(分页读取) - async getTrainJobLogs( - id: number, - params: { offset?: number; max?: number } = {} - ): Promise { - return request.get(`/api/yolo/train/jobs/${id}/logs/`, { params }) - }, - - // 取消训练任务 - async cancelTrainJob(id: number): Promise<{ - success: boolean - code?: number - message?: string - data?: YoloTrainingJob - }> { - return request.post(`/api/yolo/train/jobs/${id}/cancel/`) - }, - - // 下载训练结果(ZIP) - async downloadTrainJobResult(id: number): Promise<{ - success: boolean - code?: number - message?: string - data?: { url: string; size: number } - }> { - return request.get(`/api/yolo/train/jobs/${id}/download/`) - }, - - // 删除训练任务 - async deleteTrainJob(id: number): Promise<{ - success: boolean - code?: number - 
message?: string - }> { - return request.post(`/api/yolo/train/jobs/${id}/delete/`) - }, - - // 警告等级管理相关接口 - // 获取警告等级列表 - async getAlertLevels(): Promise<{ success: boolean; data?: AlertLevel[]; message?: string }> { - return request.get('/api/yolo/categories/') - }, - - // 获取警告等级详情 - async getAlertLevelDetail(levelId: string): Promise<{ success: boolean; data?: AlertLevel; message?: string }> { - return request.get(`/api/yolo/categories/${levelId}/`) - }, - - // 更新警告等级 - async updateAlertLevel(levelId: string, data: { alert_level?: 'low' | 'medium' | 'high'; alias?: string }): Promise<{ success: boolean; data?: AlertLevel; message?: string }> { - return request.put(`/api/yolo/categories/${levelId}/update/`, data) - }, - - // 切换警告等级状态 - async toggleAlertLevelStatus(levelId: string): Promise<{ success: boolean; data?: AlertLevel; message?: string }> { - return request.post(`/api/yolo/categories/${levelId}/toggle-status/`) - }, - - // 获取活跃的警告等级列表 - async getActiveAlertLevels(): Promise<{ success: boolean; data?: AlertLevel[]; message?: string }> { - return request.get('/api/yolo/categories/active/') - }, - - // 上传并转换PT模型为ONNX格式 - async uploadAndConvertToOnnx(formData: FormData): Promise<{ - success: boolean - message?: string - data?: { - onnx_path?: string - onnx_url?: string - download_url?: string - onnx_relative_path?: string - file_name?: string - labels_download_url?: string - labels_relative_path?: string - classes?: string[] - } - }> { - // 适配后端 @views.py 中的 upload_pt_convert_onnx 实现 - // 统一走 /api/upload_pt_convert_onnx - // 按你的后端接口:/yolo/onnx/upload/ - // 注意带上结尾斜杠,避免 404 - return request.upload('/api/yolo/onnx/upload/', formData) - } -} - -// 警告等级管理相关接口 -export interface AlertLevel { - id: number - model: number - model_name: string - name: string - alias: string - display_name: string - category_id: number - alert_level: 'low' | 'medium' | 'high' - alert_level_display: string - is_active: boolean - // 前端编辑状态字段 - editingAlias?: boolean - tempAlias?: string 
-} - -// 用户检测历史相关接口 -export interface DetectionHistoryRecord { - id: number - user_id: number - original_filename: string - result_filename: string - original_file: string - result_file: string - detection_type: 'image' | 'video' - object_count: number - detected_categories: string[] - confidence_scores: number[] - avg_confidence: number | null - processing_time: number - model_name: string - model_info: any - created_at: string - confidence_threshold?: number // 置信度阈值(原始设置值) - // 为了兼容前端显示,添加计算字段 - filename?: string - image_url?: string - detections?: YoloDetection[] -} - -export interface DetectionHistoryParams { - page?: number - page_size?: number - search?: string - class_filter?: string - start_date?: string - end_date?: string - model_id?: string -} - -export interface DetectionHistoryResponse { - success?: boolean - message?: string - data?: { - records: DetectionHistoryRecord[] - total: number - page: number - page_size: number - } | DetectionHistoryRecord[] - // 支持直接返回数组的情况 - results?: DetectionHistoryRecord[] - count?: number - // 支持Django REST framework的分页格式 - next?: string - previous?: string -} - -// 用户检测历史API -export const detectionHistoryApi = { - // 获取用户检测历史 - async getUserDetectionHistory(userId: number, params: DetectionHistoryParams = {}): Promise { - return request.get('/api/yolo/detections/', { - params: { - user_id: userId, - ...params - } - }) - }, - - // 获取检测记录详情 - async getDetectionRecordDetail(recordId: number): Promise<{ - success?: boolean - code?: number - message?: string - data?: DetectionHistoryRecord - }> { - return request.get(`/api/yolo/detections/${recordId}/`) - }, - - // 删除检测记录 - async deleteDetectionRecord(userId: number, recordId: string): Promise<{ success: boolean; message?: string }> { - return request.delete(`/api/yolo/detections/${recordId}/delete/`) - }, - - // 批量删除检测记录 - async batchDeleteDetectionRecords(userId: number, recordIds: string[]): Promise<{ success: boolean; message?: string }> { - return 
request.post('/api/yolo/detections/batch-delete/', { ids: recordIds }) - }, - - // 导出检测历史 - async exportDetectionHistory(userId: number, params: DetectionHistoryParams = {}): Promise { - const response = await request.get('/api/yolo/detections/export/', { - params: { - user_id: userId, - ...params - }, - responseType: 'blob' - }) - return response - }, - - // 获取检测统计信息 - async getDetectionStats(userId: number): Promise<{ - success: boolean - data?: { - total_detections: number - total_images: number - class_counts: Record - recent_activity: Array<{ - date: string - count: number - }> - } - message?: string - }> { - return request.get('/api/yolo/detections/stats/', { - params: { user_id: userId } - }) - } -} - -// 告警相关接口类型定义 -export interface AlertRecord { - id: number - detection_record: number - detection_info: { - id: number - detection_type: string - original_filename: string - result_filename: string - object_count: number - avg_confidence: number - } - user: number - user_name: string - alert_level: string - alert_level_display: string - alert_category: string - category: number - category_info: { - id: number - name: string - alert_level: string - alert_level_display: string - } - status: string - created_at: string - deleted_at: string | null -} - -// 告警管理API -export const alertApi = { - // 获取所有告警记录 - async getAllAlerts(): Promise<{ success: boolean; data?: AlertRecord[]; message?: string }> { - return request.get('/api/yolo/alerts/') - }, - - // 获取当前用户的告警记录 - async getUserAlerts(userId: string): Promise<{ success: boolean; data?: AlertRecord[]; message?: string }> { - return request.get(`/api/yolo/users/${userId}/alerts/`) - }, - - // 处理告警(更新状态) - async updateAlertStatus(alertId: string, status: string): Promise<{ success: boolean; data?: AlertRecord; message?: string }> { - return request.put(`/api/yolo/alerts/${alertId}/update-status/`, { status }) - } -} - -// 默认导出 -export default yoloApi diff --git a/hertz_server_django_ui/src/config/hertz_modules.ts 
b/hertz_server_django_ui/src/config/hertz_modules.ts deleted file mode 100644 index 1d1ed92..0000000 --- a/hertz_server_django_ui/src/config/hertz_modules.ts +++ /dev/null @@ -1,85 +0,0 @@ -export type HertzModuleGroup = 'admin' | 'user' - -export interface HertzModule { - key: string - label: string - group: HertzModuleGroup - description?: string - defaultEnabled: boolean -} - -export const HERTZ_MODULES: HertzModule[] = [ - { key: 'admin.user-management', label: '管理端 · 用户管理', group: 'admin', defaultEnabled: true }, - { key: 'admin.department-management', label: '管理端 · 部门管理', group: 'admin', defaultEnabled: true }, - { key: 'admin.menu-management', label: '管理端 · 菜单管理', group: 'admin', defaultEnabled: true }, - { key: 'admin.role-management', label: '管理端 · 角色管理', group: 'admin', defaultEnabled: true }, - { key: 'admin.notification-management', label: '管理端 · 通知管理', group: 'admin', defaultEnabled: true }, - { key: 'admin.log-management', label: '管理端 · 日志管理', group: 'admin', defaultEnabled: true }, - { key: 'admin.knowledge-base', label: '管理端 · 文章管理', group: 'admin', defaultEnabled: true }, - { key: 'admin.yolo-model', label: '管理端 · YOLO 模型相关', group: 'admin', defaultEnabled: true }, - - { key: 'user.system-monitor', label: '用户端 · 系统监控', group: 'user', defaultEnabled: true }, - { key: 'user.ai-chat', label: '用户端 · AI 助手', group: 'user', defaultEnabled: true }, - { key: 'user.yolo-detection', label: '用户端 · YOLO 检测', group: 'user', defaultEnabled: true }, - { key: 'user.live-detection', label: '用户端 · 实时检测', group: 'user', defaultEnabled: true }, - { key: 'user.detection-history', label: '用户端 · 检测历史', group: 'user', defaultEnabled: true }, - { key: 'user.alert-center', label: '用户端 · 告警中心', group: 'user', defaultEnabled: true }, - { key: 'user.notice-center', label: '用户端 · 通知中心', group: 'user', defaultEnabled: true }, - { key: 'user.knowledge-center', label: '用户端 · 文章中心', group: 'user', defaultEnabled: true }, - { key: 'user.kb-center', label: '用户端 · 知识库中心', group: 
'user', defaultEnabled: true }, -] - -const LOCAL_STORAGE_KEY = 'hertz_enabled_modules' - -export function getEnabledModuleKeys(): string[] { - const fallback = HERTZ_MODULES.filter(m => m.defaultEnabled).map(m => m.key) - - if (typeof window === 'undefined') { - return fallback - } - - try { - const stored = window.localStorage.getItem(LOCAL_STORAGE_KEY) - if (!stored) return fallback - const parsed = JSON.parse(stored) - if (Array.isArray(parsed)) { - const valid = parsed.filter((k): k is string => typeof k === 'string') - // 自动合并新增的默认启用模块,避免新模块在已有选择下被永久隐藏 - const missingDefaults = HERTZ_MODULES - .filter(m => m.defaultEnabled && !valid.includes(m.key)) - .map(m => m.key) - return [...valid, ...missingDefaults] - } - return fallback - } catch { - return fallback - } -} - -export function setEnabledModuleKeys(keys: string[]): void { - if (typeof window === 'undefined') return - try { - window.localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(keys)) - } catch { - // ignore - } -} - -export function isModuleEnabled(moduleKey?: string, enabledKeys?: string[]): boolean { - if (!moduleKey) return true - const keys = enabledKeys ?? 
getEnabledModuleKeys() - return keys.indexOf(moduleKey) !== -1 -} - -export function getModulesByGroup(group: HertzModuleGroup): HertzModule[] { - return HERTZ_MODULES.filter(m => m.group === group) -} - -export function hasModuleSelection(): boolean { - if (typeof window === 'undefined') return false - try { - return window.localStorage.getItem(LOCAL_STORAGE_KEY) !== null - } catch { - return false - } -} diff --git a/hertz_server_django_ui/src/locales/en-US.ts b/hertz_server_django_ui/src/locales/en-US.ts deleted file mode 100644 index 81390e0..0000000 --- a/hertz_server_django_ui/src/locales/en-US.ts +++ /dev/null @@ -1,159 +0,0 @@ -export default { - common: { - confirm: 'Confirm', - cancel: 'Cancel', - save: 'Save', - delete: 'Delete', - edit: 'Edit', - add: 'Add', - search: 'Search', - reset: 'Reset', - loading: 'Loading...', - noData: 'No Data', - success: 'Success', - error: 'Error', - warning: 'Warning', - info: 'Info', - }, - nav: { - home: 'Home', - dashboard: 'Dashboard', - user: 'User Management', - role: 'Role Management', - menu: 'Menu Management', - settings: 'System Settings', - profile: 'Profile', - logout: 'Logout', - }, - login: { - title: 'Login', - username: 'Username', - password: 'Password', - login: 'Login', - forgotPassword: 'Forgot Password?', - rememberMe: 'Remember Me', - }, - success: { - // General success messages - operationSuccess: 'Operation Successful', - saveSuccess: 'Save Successful', - deleteSuccess: 'Delete Successful', - updateSuccess: 'Update Successful', - - // Login and registration related success messages - loginSuccess: 'Login Successful', - registerSuccess: 'Registration Successful! 
Please Login', - logoutSuccess: 'Logout Successful', - emailCodeSent: 'Verification Code Sent to Your Email', - - // User management related success messages - userCreated: 'User Created Successfully', - userUpdated: 'User Information Updated Successfully', - userDeleted: 'User Deleted Successfully', - roleAssigned: 'Role Assigned Successfully', - - // Other operation success messages - uploadSuccess: 'File Upload Successful', - downloadSuccess: 'File Download Successful', - copySuccess: 'Copy Successful', - }, - error: { - // General errors - // 404: 'Page Not Found', - 403: 'Access Denied, Please Contact Administrator', - 500: 'Internal Server Error, Please Try Again Later', - networkError: 'Network Connection Failed, Please Check Network Settings', - timeout: 'Request Timeout, Please Try Again Later', - - // Login related errors - loginFailed: 'Login Failed, Please Check Username and Password', - usernameRequired: 'Please Enter Username', - passwordRequired: 'Please Enter Password', - captchaRequired: 'Please Enter Captcha', - captchaError: 'Captcha Error, Please Re-enter (Case Sensitive)', - captchaExpired: 'Captcha Expired, Please Refresh and Re-enter', - accountLocked: 'Account Locked, Please Contact Administrator', - accountDisabled: 'Account Disabled, Please Contact Administrator', - passwordExpired: 'Password Expired, Please Change Password', - loginAttemptsExceeded: 'Too Many Login Attempts, Account Temporarily Locked', - - // Registration related errors - registerFailed: 'Registration Failed, Please Check Input Information', - usernameExists: 'Username Already Exists, Please Choose Another', - emailExists: 'Email Already Registered, Please Use Another Email', - phoneExists: 'Phone Number Already Registered, Please Use Another', - emailFormatError: 'Invalid Email Format, Please Enter Valid Email', - phoneFormatError: 'Invalid Phone Format, Please Enter 11-digit Phone Number', - passwordTooWeak: 'Password Too Weak, Please Include Uppercase, Lowercase, 
Numbers and Special Characters', - passwordMismatch: 'Passwords Do Not Match', - emailCodeError: 'Email Verification Code Error or Expired', - emailCodeRequired: 'Please Enter Email Verification Code', - emailCodeLength: 'Verification Code Must Be 6 Digits', - emailRequired: 'Please Enter Email', - usernameLength: 'Username Length Must Be 3-20 Characters', - passwordLength: 'Password Length Must Be 6-20 Characters', - confirmPasswordRequired: 'Please Confirm Password', - phoneRequired: 'Please Enter Phone Number', - realNameRequired: 'Please Enter Real Name', - realNameLength: 'Name Length Must Be 2-10 Characters', - - // Permission related errors - accessDenied: 'Access Denied, You Do Not Have Permission to Perform This Action', - roleNotFound: 'Role Not Found or Deleted', - permissionDenied: 'Permission Denied, Cannot Perform This Action', - tokenExpired: 'Login Expired, Please Login Again', - tokenInvalid: 'Invalid Login Status, Please Login Again', - - // User management related errors - userNotFound: 'User Not Found or Deleted', - userCreateFailed: 'Failed to Create User, Please Check Input Information', - userUpdateFailed: 'Failed to Update User Information', - userDeleteFailed: 'Failed to Delete User, User May Be In Use', - cannotDeleteSelf: 'Cannot Delete Your Own Account', - cannotDeleteAdmin: 'Cannot Delete Administrator Account', - - // Department management related errors - departmentNotFound: 'Department Not Found or Deleted', - departmentNameExists: 'Department Name Already Exists', - departmentHasUsers: 'Department Has Users, Cannot Delete', - departmentCreateFailed: 'Failed to Create Department', - departmentUpdateFailed: 'Failed to Update Department Information', - departmentDeleteFailed: 'Failed to Delete Department', - - // Role management related errors - roleNameExists: 'Role Name Already Exists', - roleCreateFailed: 'Failed to Create Role', - roleUpdateFailed: 'Failed to Update Role Information', - roleDeleteFailed: 'Failed to Delete Role', - 
roleInUse: 'Role In Use, Cannot Delete', - - // File upload related errors - fileUploadFailed: 'File Upload Failed', - fileSizeExceeded: 'File Size Exceeded Limit', - fileTypeNotSupported: 'File Type Not Supported', - fileRequired: 'Please Select File to Upload', - - // Data validation related errors - invalidInput: 'Invalid Input Data Format', - requiredFieldMissing: 'Required Field Cannot Be Empty', - fieldTooLong: 'Input Content Exceeds Length Limit', - fieldTooShort: 'Input Content Length Insufficient', - invalidDate: 'Invalid Date Format', - invalidNumber: 'Invalid Number Format', - - // Operation related errors - operationFailed: 'Operation Failed, Please Try Again Later', - saveSuccess: 'Save Successful', - saveFailed: 'Save Failed, Please Check Input Information', - deleteSuccess: 'Delete Successful', - deleteFailed: 'Delete Failed, Please Try Again Later', - updateSuccess: 'Update Successful', - updateFailed: 'Update Failed, Please Check Input Information', - - // System related errors - systemMaintenance: 'System Under Maintenance, Please Visit Later', - serviceUnavailable: 'Service Temporarily Unavailable, Please Try Again Later', - databaseError: 'Database Connection Error, Please Contact Technical Support', - configError: 'System Configuration Error, Please Contact Administrator', - }, -} diff --git a/hertz_server_django_ui/src/locales/index.ts b/hertz_server_django_ui/src/locales/index.ts deleted file mode 100644 index 3250b2d..0000000 --- a/hertz_server_django_ui/src/locales/index.ts +++ /dev/null @@ -1,18 +0,0 @@ -import { createI18n } from 'vue-i18n' -import zhCN from './zh-CN' -import enUS from './en-US' - -const messages = { - 'zh-CN': zhCN, - 'en-US': enUS, -} - -export const i18n = createI18n({ - locale: 'zh-CN', - fallbackLocale: 'en-US', - messages, - legacy: false, - globalInjection: true, -}) - -export default i18n diff --git a/hertz_server_django_ui/src/locales/zh-CN.ts b/hertz_server_django_ui/src/locales/zh-CN.ts deleted file mode 100644 
index b529d2b..0000000 --- a/hertz_server_django_ui/src/locales/zh-CN.ts +++ /dev/null @@ -1,172 +0,0 @@ -export default { - common: { - confirm: '确定', - cancel: '取消', - save: '保存', - delete: '删除', - edit: '编辑', - add: '添 加', - search: '搜索', - reset: '重置', - loading: '加载中...', - noData: '暂无数据', - success: '成功', - error: '错误', - warning: '警告', - info: '提示', - }, - nav: { - home: '首页', - dashboard: '仪表板', - user: '用户管理', - role: '角色管理', - menu: '菜单管理', - settings: '系统设置', - profile: '个人资料', - logout: '退出登录', - }, - login: { - title: '登录', - username: '用户名', - password: '密码', - login: '登录', - forgotPassword: '忘记密码?', - rememberMe: '记住我', - }, - register: { - title: '注册', - username: '用户名', - email: '邮箱', - password: '密码', - confirmPassword: '确认密码', - register: '注册', - agreement: '我已阅读并同意', - userAgreement: '用户协议', - privacyPolicy: '隐私政策', - hasAccount: '已有账号?', - goToLogin: '立即登录', - }, - success: { - // 通用成功提示 - operationSuccess: '操作成功', - saveSuccess: '保存成功', - deleteSuccess: '删除成功', - updateSuccess: '更新成功', - - // 登录注册相关成功提示 - loginSuccess: '登录成功', - registerSuccess: '注册成功!请前往登录', - logoutSuccess: '退出登录成功', - emailCodeSent: '验证码已发送到您的邮箱', - - // 用户管理相关成功提示 - userCreated: '用户创建成功', - userUpdated: '用户信息更新成功', - userDeleted: '用户删除成功', - roleAssigned: '角色分配成功', - - // 其他操作成功提示 - uploadSuccess: '文件上传成功', - downloadSuccess: '文件下载成功', - copySuccess: '复制成功', - }, - error: { - // 通用错误 - // 404: '页面未找到', - 403: '权限不足,请联系管理员', - 500: '服务器内部错误,请稍后重试', - networkError: '网络连接失败,请检查网络设置', - timeout: '请求超时,请稍后重试', - - // 登录相关错误 - loginFailed: '登录失败,请检查用户名和密码', - usernameRequired: '请输入用户名', - passwordRequired: '请输入密码', - captchaRequired: '请输入验证码', - captchaError: '验证码错误,请重新输入(区分大小写)', - captchaExpired: '验证码已过期,请刷新后重新输入', - accountLocked: '账户已被锁定,请联系管理员', - accountDisabled: '账户已被禁用,请联系管理员', - passwordExpired: '密码已过期,请修改密码', - loginAttemptsExceeded: '登录尝试次数过多,账户已被临时锁定', - - // 注册相关错误 - registerFailed: '注册失败,请检查输入信息', - usernameExists: '用户名已存在,请选择其他用户名', - emailExists: 
'邮箱已被注册,请使用其他邮箱', - phoneExists: '手机号已被注册,请使用其他手机号', - emailFormatError: '邮箱格式不正确,请输入有效的邮箱地址', - phoneFormatError: '手机号格式不正确,请输入11位手机号', - passwordTooWeak: '密码强度不足,请包含大小写字母、数字和特殊字符', - passwordMismatch: '两次输入的密码不一致', - emailCodeError: '邮箱验证码错误或已过期', - emailCodeRequired: '请输入邮箱验证码', - emailCodeLength: '验证码长度为6位', - emailRequired: '请输入邮箱', - usernameLength: '用户名长度为3-20个字符', - passwordLength: '密码长度为6-20个字符', - confirmPasswordRequired: '请确认密码', - phoneRequired: '请输入手机号', - realNameRequired: '请输入真实姓名', - realNameLength: '姓名长度为2-10个字符', - - // 权限相关错误 - accessDenied: '访问被拒绝,您没有执行此操作的权限', - roleNotFound: '角色不存在或已被删除', - permissionDenied: '权限不足,无法执行此操作', - tokenExpired: '登录已过期,请重新登录', - tokenInvalid: '登录状态无效,请重新登录', - - // 用户管理相关错误 - userNotFound: '用户不存在或已被删除', - userCreateFailed: '创建用户失败,请检查输入信息', - userUpdateFailed: '更新用户信息失败', - userDeleteFailed: '删除用户失败,该用户可能正在使用中', - cannotDeleteSelf: '不能删除自己的账户', - cannotDeleteAdmin: '不能删除管理员账户', - - // 部门管理相关错误 - departmentNotFound: '部门不存在或已被删除', - departmentNameExists: '部门名称已存在', - departmentHasUsers: '部门下还有用户,无法删除', - departmentCreateFailed: '创建部门失败', - departmentUpdateFailed: '更新部门信息失败', - departmentDeleteFailed: '删除部门失败', - - // 角色管理相关错误 - roleNameExists: '角色名称已存在', - roleCreateFailed: '创建角色失败', - roleUpdateFailed: '更新角色信息失败', - roleDeleteFailed: '删除角色失败', - roleInUse: '角色正在使用中,无法删除', - - // 文件上传相关错误 - fileUploadFailed: '文件上传失败', - fileSizeExceeded: '文件大小超出限制', - fileTypeNotSupported: '不支持的文件类型', - fileRequired: '请选择要上传的文件', - - // 数据验证相关错误 - invalidInput: '输入数据格式不正确', - requiredFieldMissing: '必填字段不能为空', - fieldTooLong: '输入内容超出长度限制', - fieldTooShort: '输入内容长度不足', - invalidDate: '日期格式不正确', - invalidNumber: '数字格式不正确', - - // 操作相关错误 - operationFailed: '操作失败,请稍后重试', - saveSuccess: '保存成功', - saveFailed: '保存失败,请检查输入信息', - deleteSuccess: '删除成功', - deleteFailed: '删除失败,请稍后重试', - updateSuccess: '更新成功', - updateFailed: '更新失败,请检查输入信息', - - // 系统相关错误 - systemMaintenance: '系统正在维护中,请稍后访问', - serviceUnavailable: '服务暂时不可用,请稍后重试', - databaseError: 
'数据库连接错误,请联系技术支持', - configError: '系统配置错误,请联系管理员', - }, -} diff --git a/hertz_server_django_ui/src/main.ts b/hertz_server_django_ui/src/main.ts deleted file mode 100644 index 34add5f..0000000 --- a/hertz_server_django_ui/src/main.ts +++ /dev/null @@ -1,47 +0,0 @@ -import { createApp } from 'vue' -import { createPinia } from 'pinia' -import App from './App.vue' -import router from './router' -import { i18n } from './locales' -import { checkEnvironmentVariables, validateEnvironment } from './utils/hertz_env' -import './styles/index.scss' - -// 导入Ant Design Vue -import 'ant-design-vue/dist/antd.css' - -// 开发环境检查 -if (import.meta.env.DEV) { - checkEnvironmentVariables() - validateEnvironment() -} - -// 创建Vue应用实例 -const app = createApp(App) - -// 使用Pinia状态管理 -const pinia = createPinia() -app.use(pinia) - -// 使用路由 -app.use(router) - -// 使用国际化 -app.use(i18n) - -// 初始化应用设置 -import { useAppStore } from './stores/hertz_app' -const appStore = useAppStore() -appStore.initAppSettings() - -// 检查用户认证状态 -import { useUserStore } from './stores/hertz_user' -const userStore = useUserStore() -userStore.checkAuth() - -// 初始化主题(必须在挂载前加载) -import { useThemeStore } from './stores/hertz_theme' -const themeStore = useThemeStore() -themeStore.loadTheme() - -// 挂载应用 -app.mount('#app') diff --git a/hertz_server_django_ui/src/router/admin_menu.ts b/hertz_server_django_ui/src/router/admin_menu.ts deleted file mode 100644 index e9235aa..0000000 --- a/hertz_server_django_ui/src/router/admin_menu.ts +++ /dev/null @@ -1,459 +0,0 @@ -import type { RouteRecordRaw } from "vue-router"; -import { getEnabledModuleKeys, isModuleEnabled } from "@/config/hertz_modules"; - -// 角色权限枚举 -export enum UserRole { - ADMIN = 'admin', - SYSTEM_ADMIN = 'system_admin', - NORMAL_USER = 'normal_user', - SUPER_ADMIN = 'super_admin' -} - -// 统一菜单配置接口 - 只需要在这里配置一次 -export interface AdminMenuItem { - key: string; // 菜单唯一标识 - title: string; // 菜单标题 - icon?: string; // 菜单图标 - path: string; // 路由路径 - component: string; // 
组件路径(相对于@/views/admin_page/) - isDefault?: boolean; // 是否为默认路由(首页) - roles?: UserRole[]; // 允许访问的角色,不设置则使用默认管理员角色 - permission?: string; // 所需权限标识符 - children?: AdminMenuItem[]; // 子菜单 - moduleKey?: string; -} - -// 🎯 统一配置中心 - 只需要在这里修改菜单配置 -export const ADMIN_MENU_CONFIG: AdminMenuItem[] = [ - { - key: "dashboard", - title: "仪表盘", - icon: "DashboardOutlined", - path: "/admin", - component: "Dashboard.vue", - isDefault: true, // 标记为默认首页 - }, - { - key: "user-management", - title: "用户管理", - icon: "UserOutlined", - path: "/admin/user-management", - component: "UserManagement.vue", - permission: "system:user:list", // 需要用户列表权限 - moduleKey: "admin.user-management", - }, - { - key: "department-management", - title: "部门管理", - icon: "SettingOutlined", - path: "/admin/department-management", - component: "DepartmentManagement.vue", - permission: "system:dept:list", // 需要部门列表权限 - moduleKey: "admin.department-management", - }, - { - key: "menu-management", - title: "菜单管理", - icon: "SettingOutlined", - path: "/admin/menu-management", - component: "MenuManagement.vue", - permission: "system:menu:list", // 需要菜单列表权限 - moduleKey: "admin.menu-management", - }, - { - key: "teacher", - title: "角色管理", - icon: "UserOutlined", - path: "/admin/teacher", - component: "Role.vue", - permission: "system:role:list", // 需要角色列表权限 - moduleKey: "admin.role-management", - }, - { - key: "notification-management", - title: "通知管理", - icon: "UserOutlined", - path: "/admin/notification-management", - component: "NotificationManagement.vue", - permission: "studio:notice:list", // 需要通知列表权限 - moduleKey: "admin.notification-management", - }, - { - key: "log-management", - title: "日志管理", - icon: "FileSearchOutlined", - path: "/admin/log-management", - component: "LogManagement.vue", - permission: "log.view_operationlog", // 查看操作日志权限 - moduleKey: "admin.log-management", - }, - { - key: "knowledge-base", - title: "文章管理", - icon: "DatabaseOutlined", - path: "/admin/article-management", - component: 
"ArticleManagement.vue", - // 菜单访问权限:需要具备文章列表权限 - permission: "system:knowledge:article:list", - moduleKey: "admin.knowledge-base", - }, - { - key: "yolo-model", - title: "YOLO模型", - icon: "ClusterOutlined", - path: "/admin/yolo-model", - component: "ModelManagement.vue", // 默认显示模型管理页面 - // 父菜单不设置权限,由子菜单的权限决定是否显示 - moduleKey: "admin.yolo-model", - children: [ - { - key: "model-management", - title: "模型管理", - icon: "RobotOutlined", - path: "/admin/model-management", - component: "ModelManagement.vue", - permission: "system:yolo:model:list", - }, - { - key: "dataset-management", - title: "数据集管理", - icon: "DatabaseOutlined", - path: "/admin/dataset-management", - component: "DatasetManagement.vue", - }, - { - key: "yolo-train-management", - title: "YOLO训练", - icon: "HistoryOutlined", - path: "/admin/yolo-train", - component: "YoloTrainManagement.vue", - }, - { - key: "alert-level-management", - title: "模型类别管理", - icon: "WarningOutlined", - path: "/admin/alert-level-management", - component: "AlertLevelManagement.vue", - permission: "system:yolo:alert:list", - }, - { - key: "alert-processing-center", - title: "告警处理中心", - icon: "BellOutlined", - path: "/admin/alert-processing-center", - component: "AlertProcessingCenter.vue", - permission: "system:yolo:alert:process", - }, - { - key: "detection-history-management", - title: "检测历史管理", - icon: "HistoryOutlined", - path: "/admin/detection-history-management", - component: "DetectionHistoryManagement.vue", - permission: "system:yolo:history:list", - }, - ], - }, -]; - -// 默认管理员角色 - 修改为空数组,通过自定义权限检查函数处理 -const DEFAULT_ADMIN_ROLES: UserRole[] = []; - -// 组件映射 - 静态导入以支持Vite分析 -const COMPONENT_MAP: { [key: string]: () => Promise } = { - 'Dashboard.vue': () => import("@/views/admin_page/Dashboard.vue"), - 'UserManagement.vue': () => import("@/views/admin_page/UserManagement.vue"), - 'DepartmentManagement.vue': () => import("@/views/admin_page/DepartmentManagement.vue"), - 'Role.vue': () => import("@/views/admin_page/Role.vue"), 
- 'MenuManagement.vue': () => import("@/views/admin_page/MenuManagement.vue"), - 'NotificationManagement.vue': () => import("@/views/admin_page/NotificationManagement.vue"), - 'LogManagement.vue': () => import("@/views/admin_page/LogManagement.vue"), - 'ArticleManagement.vue': () => import("@/views/admin_page/ArticleManagement.vue"), - 'ModelManagement.vue': () => import("@/views/admin_page/ModelManagement.vue"), - 'DatasetManagement.vue': () => import("@/views/admin_page/DatasetManagement.vue"), - 'YoloTrainManagement.vue': () => import("@/views/admin_page/YoloTrainManagement.vue"), - 'AlertLevelManagement.vue': () => import("@/views/admin_page/AlertLevelManagement.vue"), - 'AlertProcessingCenter.vue': () => import("@/views/admin_page/AlertProcessingCenter.vue"), - 'DetectionHistoryManagement.vue': () => import("@/views/admin_page/DetectionHistoryManagement.vue"), -}; - -// 🚀 自动生成路由配置 -function generateAdminRoutes(): RouteRecordRaw { - const children: RouteRecordRaw[] = []; - const enabledModuleKeys = getEnabledModuleKeys(); - - ADMIN_MENU_CONFIG.forEach(item => { - if (!isModuleEnabled(item.moduleKey, enabledModuleKeys)) { - return; - } - // 如果有子菜单,将子菜单作为独立的路由项 - if (item.children && item.children.length > 0) { - // 为每个子菜单创建独立的路由 - item.children.forEach(child => { - children.push({ - path: child.path.replace("/admin/", ""), - name: child.key, - component: COMPONENT_MAP[child.component] || (() => import("@/views/admin_page/Dashboard.vue")), - meta: { - title: child.title, - requiresAuth: true, - roles: child.roles || DEFAULT_ADMIN_ROLES, - }, - }); - }); - } else { - // 没有子菜单的普通菜单项 - children.push({ - path: item.isDefault ? 
"" : item.path.replace("/admin/", ""), - name: item.key, - component: COMPONENT_MAP[item.component] || (() => import("@/views/admin_page/Dashboard.vue")), - meta: { - title: item.title, - requiresAuth: true, - roles: item.roles || DEFAULT_ADMIN_ROLES, - }, - }); - } - }); - - console.log('🛣️ 生成的管理端路由配置:', children.map(child => ({ - path: child.path, - name: child.name, - title: child.meta?.title - }))); - - return { - path: "/admin", - name: "Admin", - component: () => import("@/views/admin_page/index.vue"), - meta: { - title: "管理后台", - requiresAuth: true, - roles: DEFAULT_ADMIN_ROLES, - }, - children, - }; -} - -// 🚀 自动生成菜单配置 -export interface MenuConfig { - key: string; - title: string; - icon?: string; - path: string; - children?: MenuConfig[]; -} - -function generateMenuConfig(): MenuConfig[] { - return ADMIN_MENU_CONFIG.map(item => ({ - key: item.key, - title: item.title, - icon: item.icon, - path: item.path, - children: item.children?.map(child => ({ - key: child.key, - title: child.title, - icon: child.icon, - path: child.path, - })), - })); -} - -// 🚀 自动生成路径映射函数 -function generatePathKeyMapping(): { [path: string]: string } { - const mapping: { [path: string]: string } = {}; - - function addToMapping(items: AdminMenuItem[], parentPath = '') { - items.forEach(item => { - mapping[item.path] = item.key; - if (item.children) { - addToMapping(item.children, item.path); - } - }); - } - - addToMapping(ADMIN_MENU_CONFIG); - return mapping; -} - -// 导出的配置和函数 -export const adminMenuRoutes: RouteRecordRaw = generateAdminRoutes(); -export const adminMenuConfig: MenuConfig[] = generateMenuConfig(); - -// 路径到key的映射 -const pathKeyMapping = generatePathKeyMapping(); - -// 🎯 根据路径获取菜单key - 自动生成 -export const getMenuKeyByPath = (path: string): string => { - // 精确匹配 - if (pathKeyMapping[path]) { - return pathKeyMapping[path]; - } - - // 模糊匹配 - for (const [mappedPath, key] of Object.entries(pathKeyMapping)) { - if (path.includes(mappedPath) && mappedPath !== '/admin') { - 
return key; - } - } - - // 默认返回dashboard - return 'dashboard'; -}; - -// 🎯 根据菜单key获取路径 - 自动生成 -export const getPathByMenuKey = (key: string): string => { - console.log('🔍 查找菜单路径:', key); - - const menuItem = ADMIN_MENU_CONFIG.find(item => item.key === key); - if (menuItem) { - console.log('✅ 找到父菜单路径:', menuItem.path); - return menuItem.path; - } - - // 在子菜单中查找 - for (const item of ADMIN_MENU_CONFIG) { - if (item.children) { - const childItem = item.children.find(child => child.key === key); - if (childItem) { - console.log('✅ 找到子菜单路径:', childItem.path); - return childItem.path; - } - } - } - - console.log('❌ 未找到菜单路径,返回默认路径'); - return '/admin'; -}; - -// 🎯 根据菜单key获取标题 - 自动生成 -export const getTitleByMenuKey = (key: string): string => { - const menuItem = ADMIN_MENU_CONFIG.find(item => item.key === key); - if (menuItem) return menuItem.title; - - // 在子菜单中查找 - for (const item of ADMIN_MENU_CONFIG) { - if (item.children) { - const childItem = item.children.find(child => child.key === key); - if (childItem) return childItem.title; - } - } - - return '仪表盘'; -}; - -// 菜单权限检查 -export const hasMenuPermission = (menuKey: string, userRole: string): boolean => { - const menuItem = ADMIN_MENU_CONFIG.find(item => item.key === menuKey); - if (!menuItem) return false; - - return menuItem.roles ? 
menuItem.roles.includes(userRole as UserRole) : DEFAULT_ADMIN_ROLES.includes(userRole as UserRole); -}; - -// 🎯 新增:根据用户权限过滤菜单配置 -export const getFilteredMenuConfig = (userRoles: string[], userPermissions: string[], userMenuPermissions?: number[]): MenuConfig[] => { - const userRole = userRoles[0]; // 取第一个角色作为主要角色 - - // 仅管理员角色显示管理端菜单 - const adminRoles = ['admin', 'system_admin', 'super_admin']; - const isAdminRole = userRoles.some(r => adminRoles.includes(r)); - if (!isAdminRole) { - return []; - } - - // 对 super_admin / system_admin 开放所有管理菜单(忽略权限字符串过滤) - const isPrivilegedAdmin = userRoles.includes('super_admin') || userRoles.includes('system_admin'); - - const enabledModuleKeys = getEnabledModuleKeys(); - - // 过滤菜单项 - 基于模块开关和权限字符串检查 - const filteredMenus = ADMIN_MENU_CONFIG.filter(menuItem => { - if (!isModuleEnabled(menuItem.moduleKey, enabledModuleKeys)) { - return false; - } - console.log(`🔍 检查菜单项: ${menuItem.title} (${menuItem.key})`, { - hasPermission: !!menuItem.permission, - permission: menuItem.permission, - hasChildren: !!(menuItem.children && menuItem.children.length > 0), - childrenCount: menuItem.children?.length || 0 - }); - - // 如果菜单没有配置权限要求,则默认允许访问(如仪表盘) - if (!menuItem.permission) { - console.log(`✅ 菜单 ${menuItem.title} 无权限要求,允许访问`); - return true; - } - - // 检查用户是否有该菜单所需的权限 - const hasMenuPermission = isPrivilegedAdmin ? 
true : hasPermission(menuItem.permission, userPermissions); - - if (!hasMenuPermission) { - console.log(`❌ 菜单 ${menuItem.title} 权限不足,拒绝访问`); - return false; - } - - // 如果有子菜单,过滤子菜单 - if (menuItem.children && menuItem.children.length > 0) { - const filteredChildren = menuItem.children.filter(child => { - // 如果子菜单没有配置权限要求,则默认允许访问 - if (!child.permission) { - console.log(`✅ 子菜单 ${child.title} 无权限要求,允许访问`); - return true; - } - - const childHasPermission = hasPermission(child.permission, userPermissions); - console.log(`🔍 子菜单 ${child.title} 权限检查:`, { - permission: child.permission, - hasPermission: childHasPermission - }); - return childHasPermission; - }); - - console.log(`📊 菜单 ${menuItem.title} 子菜单过滤结果:`, { - originalCount: menuItem.children.length, - filteredCount: filteredChildren.length, - filteredChildren: filteredChildren.map(c => c.title) - }); - - // 如果没有任何子菜单有权限,则不显示父菜单 - if (filteredChildren.length === 0) { - console.log(`❌ 菜单 ${menuItem.title} 所有子菜单都无权限,隐藏父菜单`); - return false; - } - - // 更新子菜单列表 - menuItem.children = filteredChildren; - } - - console.log(`✅ 菜单 ${menuItem.title} 通过权限检查`); - return true; - }).map(menuItem => ({ - key: menuItem.key, - title: menuItem.title, - icon: menuItem.icon, - path: menuItem.path, - children: menuItem.children?.map(child => ({ - key: child.key, - title: child.title, - icon: child.icon, - path: child.path - })) - })); - - return filteredMenus; -}; - -// 🎯 新增:检查用户是否有任何管理员菜单权限 -// 修改逻辑:只有normal_user角色不能访问管理端,其他所有角色都可以访问 -export const hasAnyAdminPermission = (userRoles: string[]): boolean => { - // 仅当包含 admin/system_admin/super_admin 之一才视为管理员 - const adminRoles = ['admin', 'system_admin', 'super_admin']; - return userRoles.some(role => adminRoles.includes(role)); -}; - -/** - * 检查用户是否有指定权限 - */ -const hasPermission = (permission: string, userPermissions: string[]): boolean => { - return userPermissions.includes(permission); -}; diff --git a/hertz_server_django_ui/src/router/index.ts 
b/hertz_server_django_ui/src/router/index.ts deleted file mode 100644 index b9f8549..0000000 --- a/hertz_server_django_ui/src/router/index.ts +++ /dev/null @@ -1,295 +0,0 @@ -import { createRouter, createWebHistory } from "vue-router"; -import type { RouteRecordRaw } from "vue-router"; -import { useUserStore } from "@/stores/hertz_user"; -import { adminMenuRoutes, UserRole } from "./admin_menu"; -import { userRoutes } from "./user_menu_ai"; -import { hasModuleSelection } from "@/config/hertz_modules"; - -// 固定路由配置 -const fixedRoutes: RouteRecordRaw[] = [ - { - path: "/", - name: "Home", - component: () => import("@/views/Home.vue"), - meta: { - title: "首页", - requiresAuth: false, - }, - children: [...generateDynamicRoutes("public")], - }, - { - path: "/login", - name: "Login", - component: () => import("@/views/Login.vue"), - meta: { - title: "登录", - requiresAuth: false, - }, - }, - { - path: "/template/modules", - name: "ModuleSetup", - component: () => import("@/views/ModuleSetup.vue"), - meta: { - title: "模块配置", - requiresAuth: false, - }, - }, - { - path: "/register", - name: "Register", - component: () => import("@/views/register.vue"), - meta: { - title: "注册", - requiresAuth: false, - }, - }, - // 管理端路由 - 从admin_menu.ts导入 - adminMenuRoutes, -]; - -// 动态生成路由配置 -function generateDynamicRoutes(targetDir: string = ""): RouteRecordRaw[] { - if (!targetDir) { - return []; - } - const viewsContext = import.meta.glob("@/views/**/*.vue", { eager: true }); - - return Object.entries(viewsContext) - .map(([path, component]) => { - const relativePath = path.match(/\/views\/(.+?)\.vue$/)?.[1]; - if (!relativePath) return null; - - const fileName = relativePath.replace(".vue", ""); - const routeName = fileName.split("/").pop()!; - - // 过滤条件 - if (targetDir && !fileName.startsWith(targetDir)) { - return null; - } - - // 生成路径和标题 - const routePath = `/${fileName.replace(/([A-Z])/g, "$1").toLowerCase()}`; - const requiresAuth = - (!routePath.startsWith("/demo") && 
!routePath.startsWith("/public")) || routePath.startsWith("/user_pages")&& routePath.startsWith("/admin_page"); - const pageTitle = (component as any)?.default?.title; - - // 根据路径设置角色权限 - let roles: UserRole[] = []; - if (routePath.startsWith("/admin_page")) { - roles = [UserRole.ADMIN, UserRole.SYSTEM_ADMIN, UserRole.SUPER_ADMIN]; - } else if (routePath.startsWith("/user_pages")) { - roles = [UserRole.NORMAL_USER, UserRole.ADMIN, UserRole.SYSTEM_ADMIN, UserRole.SUPER_ADMIN]; - } else if (routePath.startsWith("/demo")) { - roles = []; // demo页面不需要特定角色 - } - - return { - path: routePath, - name: routeName, - component: () => import(/* @vite-ignore */ path), - meta: { - title: pageTitle, - requiresAuth, - roles: requiresAuth ? roles : [] - }, - }; - }) - .filter(Boolean) as RouteRecordRaw[]; -} - -// 合并固定路由和动态路由 -const routes: RouteRecordRaw[] = [ - ...fixedRoutes, - ...userRoutes, // 用户菜单路由 - 现在通过统一配置自动生成 - ...generateDynamicRoutes("demo"), // 生成demo文件夹的路由 - ...generateDynamicRoutes("admin_page"),//生成admin_page文件夹的路由 - // 404页面始终放在最后 - { - path: "/:pathMatch(.*)*", - name: "NotFound", - component: () => import("@/views/NotFound.vue"), - meta: { - title: "页面未找到", - requiresAuth: false, - }, - }, -]; - -// 创建路由实例 -const router = createRouter({ - history: createWebHistory(), - routes, - scrollBehavior(_to, _from, savedPosition) { - if (savedPosition) { - return savedPosition; - } else { - return { top: 0 }; - } - }, -}); - -// 递归打印路由信息 -function printRoute(route: RouteRecordRaw, level: number = 0) { - const indent = " ".repeat(level); - const icon = route.meta.requiresAuth ? "🔒" : "🔓"; - const auth = route.meta.requiresAuth ? 
"需要登录" : "公开访问"; - console.log(`${indent}${icon} ${route.path} → ${route.meta.title} (${auth})`); - - // 递归打印子路由 - if (route.children && route.children.length > 0) { - route.children.forEach((child) => printRoute(child, level + 1)); - } -} - -// 路由调试信息 -function logRouteInfo() { - console.log("🚀 管理系统 路由配置:"); - console.log("📋 路由列表:"); - - routes.forEach((route) => printRoute(route)); - - console.log(" ❓ /:pathMatch(.*)* → NotFound (页面未找到)"); - console.log("✅ 路由配置完成!"); -} - -// 重定向计数器,防止无限重定向 -let redirectCount = 0; -const MAX_REDIRECTS = 3; - -// 路由守卫 -router.beforeEach((to, _from, next) => { - const userStore = useUserStore(); - - // 调试信息 - console.log('🛡️ 路由守卫检查'); - console.log('📍 目标路由:', to.path, to.name); - console.log('🔐 需要认证:', to.meta.requiresAuth); - console.log('👤 用户登录状态:', userStore.isLoggedIn); - console.log('🎫 Token:', userStore.token ? '存在' : '不存在'); - console.log('📋 用户信息:', userStore.userInfo); - console.log('🔄 重定向计数:', redirectCount); - - // 模板模式:首次必须先完成模块选择 - const isTemplateMode = import.meta.env.VITE_TEMPLATE_SETUP_MODE === 'true'; - if (isTemplateMode && to.name !== "ModuleSetup") { - if (!hasModuleSelection()) { - console.log('🧩 模板模式开启,尚未选择模块,重定向到模块配置页'); - next({ name: "ModuleSetup", query: { redirect: to.fullPath } }); - return; - } - } - - // 设置页面标题 - if (to.meta.title) { - document.title = `${to.meta.title} - 管理系统`; - } - - // 检查是否需要登录 - if (to.meta.requiresAuth && !userStore.isLoggedIn) { - console.log('❌ 需要登录但用户未登录,重定向到登录页'); - redirectCount++; - if (redirectCount > MAX_REDIRECTS) { - console.log('⚠️ 重定向次数过多,强制跳转到首页'); - redirectCount = 0; - next({ name: "Home" }); - return; - } - next({ name: "Login", query: { redirect: to.fullPath } }); - return; - } - - // 已登录用户访问登录页,根据角色重定向到对应首页 - if (to.name === "Login" && userStore.isLoggedIn) { - const userRole = userStore.userInfo?.roles?.[0]?.role_code; - console.log('🔄 路由守卫 - 已登录用户访问登录页'); - console.log('👤 当前用户角色:', userRole); - console.log('📋 用户信息:', userStore.userInfo); - - // 重置重定向计数器 - 
redirectCount = 0; - - // 仅管理员角色进入管理端,其余(含未定义)进入用户端 - const adminRoles = [UserRole.ADMIN, UserRole.SYSTEM_ADMIN, UserRole.SUPER_ADMIN]; - const isAdmin = adminRoles.includes(userRole as UserRole); - if (isAdmin) { - console.log('➡️ 重定向到管理端首页'); - next({ name: "Admin" }); - } else { - console.log('➡️ 重定向到用户端首页'); - next({ name: "UserDashboard" }); - } - return; - } - - // 检查角色权限 - if (to.meta.requiresAuth && to.meta.roles && Array.isArray(to.meta.roles)) { - const userRole = userStore.userInfo?.roles?.[0]?.role_code; - - // 特殊处理:如果是管理端路由,使用自定义权限检查 - let hasPermission = false; - if (to.path.startsWith('/admin')) { - // 管理端路由:仅 admin/system_admin/super_admin 可访问 - const adminRoles = [UserRole.ADMIN, UserRole.SYSTEM_ADMIN, UserRole.SUPER_ADMIN]; - hasPermission = adminRoles.includes(userRole as UserRole); - } else { - // 其他路由:使用原有的角色检查逻辑 - hasPermission = to.meta.roles.length === 0 || to.meta.roles.includes(userRole as UserRole); - } - - console.log('🔐 路由权限检查'); - console.log('📍 目标路由:', to.path, to.name); - console.log('🎭 需要的角色:', to.meta.roles); - console.log('👤 用户角色:', userRole); - console.log('🏢 是否为管理端路由:', to.path.startsWith('/admin')); - console.log('✅ 是否有权限:', hasPermission); - - if (!hasPermission) { - console.log('❌ 权限不足,准备重定向'); - - // 增加重定向计数 - redirectCount++; - - // 防止无限重定向 - if (redirectCount > MAX_REDIRECTS) { - console.log('⚠️ 重定向次数过多,强制跳转到首页'); - redirectCount = 0; - next({ name: "Home" }); - return; - } - - // 防止无限重定向:检查是否已经在重定向过程中 - if (to.name === 'Admin' || to.name === 'UserDashboard') { - console.log('⚠️ 检测到重定向循环,强制跳转到首页'); - redirectCount = 0; - next({ name: "Home" }); - return; - } - - // 没有权限,根据用户角色重定向到对应首页 - // 只有normal_user角色跳转到用户端,其他角色(包括未定义的)都跳转到管理端 - if (userRole === 'normal_user') { - console.log('➡️ 重定向到用户端首页'); - next({ name: "UserDashboard" }); - } else { - console.log('➡️ 重定向到管理端首页 (角色:', userRole || '未定义', ')'); - next({ name: "Admin" }); - } - return; - } - } - - // 成功通过所有检查,重置重定向计数器 - redirectCount = 0; - next(); -}); - -// 路由错误处理 
-router.onError((error) => { - console.error("路由错误:", error); -}); - -// 输出路由信息 -logRouteInfo(); - -export default router; diff --git a/hertz_server_django_ui/src/router/user_menu_ai.ts b/hertz_server_django_ui/src/router/user_menu_ai.ts deleted file mode 100644 index e228d4e..0000000 --- a/hertz_server_django_ui/src/router/user_menu_ai.ts +++ /dev/null @@ -1,194 +0,0 @@ -import type { RouteRecordRaw } from 'vue-router' -import { defineAsyncComponent } from 'vue' -import { getEnabledModuleKeys, isModuleEnabled } from '@/config/hertz_modules' - -export interface UserMenuConfig { - key: string - label: string - icon?: string - path: string - component: string - children?: UserMenuConfig[] - disabled?: boolean - meta?: { - title?: string - requiresAuth?: boolean - roles?: string[] - [key: string]: any - } - moduleKey?: string -} - -export interface MenuItem { - key: string - label: string - icon?: string - path?: string - children?: MenuItem[] - disabled?: boolean -} - -export const userMenuConfigs: UserMenuConfig[] = [ - { key: 'dashboard', label: '首页', icon: 'DashboardOutlined', path: '/dashboard', component: 'index.vue', meta: { title: '用户首页', requiresAuth: true } }, - { key: 'profile', label: '个人信息', icon: 'UserOutlined', path: '/user/profile', component: 'Profile.vue', meta: { title: '个人信息', requiresAuth: true, hideInMenu: true } }, - // { key: 'documents', label: '文档管理', icon: 'FileTextOutlined', path: '/user/documents', component: 'Documents.vue', meta: { title: '文档管理', requiresAuth: true } }, - { key: 'system-monitor', label: '系统监控', icon: 'DashboardOutlined', path: '/user/system-monitor', component: 'SystemMonitor.vue', meta: { title: '系统监控', requiresAuth: true }, moduleKey: 'user.system-monitor' }, - { key: 'ai-chat', label: 'AI助手', icon: 'MessageOutlined', path: '/user/ai-chat', component: 'AiChat.vue', meta: { title: 'AI助手', requiresAuth: true }, moduleKey: 'user.ai-chat' }, - { key: 'yolo-detection', label: 'YOLO检测', icon: 'ScanOutlined', path: 
'/user/yolo-detection', component: 'YoloDetection.vue', meta: { title: 'YOLO检测中心', requiresAuth: true }, moduleKey: 'user.yolo-detection' }, - { key: 'live-detection', label: '实时检测', icon: 'VideoCameraOutlined', path: '/user/live-detection', component: 'LiveDetection.vue', meta: { title: '实时检测', requiresAuth: true }, moduleKey: 'user.live-detection' }, - { key: 'detection-history', label: '检测历史', icon: 'HistoryOutlined', path: '/user/detection-history', component: 'DetectionHistory.vue', meta: { title: '检测历史记录', requiresAuth: true }, moduleKey: 'user.detection-history' }, - { key: 'alert-center', label: '告警中心', icon: 'ExclamationCircleOutlined', path: '/user/alert-center', component: 'AlertCenter.vue', meta: { title: '告警中心', requiresAuth: true }, moduleKey: 'user.alert-center' }, - { key: 'notice-center', label: '通知中心', icon: 'BellOutlined', path: '/user/notice', component: 'NoticeCenter.vue', meta: { title: '通知中心', requiresAuth: true }, moduleKey: 'user.notice-center' }, - { key: 'knowledge-center', label: '文章中心', icon: 'DatabaseOutlined', path: '/user/knowledge', component: 'ArticleCenter.vue', meta: { title: '文章中心', requiresAuth: true }, moduleKey: 'user.knowledge-center' }, - { key: 'kb-center', label: '知识库中心', icon: 'DatabaseOutlined', path: '/user/kb-center', component: 'KbCenter.vue', meta: { title: '知识库中心', requiresAuth: true }, moduleKey: 'user.kb-center' }, -] - -const enabledModuleKeys = getEnabledModuleKeys() - -const effectiveUserMenuConfigs: UserMenuConfig[] = userMenuConfigs.filter(config => - isModuleEnabled(config.moduleKey, enabledModuleKeys) -) - -const explicitComponentMap: Record = { - 'index.vue': defineAsyncComponent(() => import('@/views/user_pages/index.vue')), - 'Profile.vue': defineAsyncComponent(() => import('@/views/user_pages/Profile.vue')), - 'Documents.vue': defineAsyncComponent(() => import('@/views/user_pages/Documents.vue')), - 'Messages.vue': defineAsyncComponent(() => import('@/views/user_pages/Messages.vue')), - 
'SystemMonitor.vue': defineAsyncComponent(() => import('@/views/user_pages/SystemMonitor.vue')), - 'AiChat.vue': defineAsyncComponent(() => import('@/views/user_pages/AiChat.vue')), - 'YoloDetection.vue': defineAsyncComponent(() => import('@/views/user_pages/YoloDetection.vue')), - 'LiveDetection.vue': defineAsyncComponent(() => import('@/views/user_pages/LiveDetection.vue')), - 'DetectionHistory.vue': defineAsyncComponent(() => import('@/views/user_pages/DetectionHistory.vue')), - 'AlertCenter.vue': defineAsyncComponent(() => import('@/views/user_pages/AlertCenter.vue')), - 'NoticeCenter.vue': defineAsyncComponent(() => import('@/views/user_pages/NoticeCenter.vue')), - 'ArticleCenter.vue': defineAsyncComponent(() => import('@/views/user_pages/ArticleCenter.vue')), - 'KbCenter.vue': defineAsyncComponent(() => import('@/views/user_pages/KbCenter.vue')), -} - -export const userMenuItems: MenuItem[] = effectiveUserMenuConfigs.map(config => ({ - key: config.key, - label: config.label, - icon: config.icon, - path: config.path, - disabled: config.disabled, - children: config.children?.map(child => ({ key: child.key, label: child.label, icon: child.icon, path: child.path, disabled: child.disabled })) -})) - -const componentMap: Record Promise> = { - 'index.vue': () => import('@/views/user_pages/index.vue'), - 'Profile.vue': () => import('@/views/user_pages/Profile.vue'), - 'Documents.vue': () => import('@/views/user_pages/Documents.vue'), - 'Messages.vue': () => import('@/views/user_pages/Messages.vue'), - 'SystemMonitor.vue': () => import('@/views/user_pages/SystemMonitor.vue'), - 'AiChat.vue': () => import('@/views/user_pages/AiChat.vue'), - 'YoloDetection.vue': () => import('@/views/user_pages/YoloDetection.vue'), - 'LiveDetection.vue': () => import('@/views/user_pages/LiveDetection.vue'), - 'DetectionHistory.vue': () => import('@/views/user_pages/DetectionHistory.vue'), - 'AlertCenter.vue': () => import('@/views/user_pages/AlertCenter.vue'), - 'NoticeCenter.vue': () 
=> import('@/views/user_pages/NoticeCenter.vue'), - 'ArticleCenter.vue': () => import('@/views/user_pages/ArticleCenter.vue'), - 'KbCenter.vue': () => import('@/views/user_pages/KbCenter.vue'), -} - -const baseRoutes: RouteRecordRaw[] = effectiveUserMenuConfigs.map(config => { - const route: RouteRecordRaw = { - path: config.path, - name: `User${config.key.charAt(0).toUpperCase() + config.key.slice(1)}`, - component: componentMap[config.component] || (() => import('@/views/NotFound.vue')), - meta: { title: config.meta?.title || config.label, requiresAuth: config.meta?.requiresAuth ?? true, ...config.meta } - } - if (config.children && config.children.length > 0) { - route.children = config.children.map(child => ({ - path: child.path, - name: `User${child.key.charAt(0).toUpperCase() + child.key.slice(1)}`, - component: componentMap[child.component] || (() => import('@/views/NotFound.vue')), - meta: { title: child.meta?.title || child.label, requiresAuth: child.meta?.requiresAuth ?? true, ...child.meta } - })) - } - return route -}) - -// 文章详情独立页面(不在菜单展示) -const knowledgeDetailRoute: RouteRecordRaw = { - path: '/user/knowledge/:id', - name: 'UserKnowledgeDetail', - component: () => import('@/views/user_pages/ArticleDetail.vue'), - meta: { title: '文章详情', requiresAuth: true, hideInMenu: true } -} - -export const userRoutes: RouteRecordRaw[] = [...baseRoutes, knowledgeDetailRoute] - -export function getMenuPath(menuKey: string): string { - const findPath = (items: MenuItem[], key: string): string | null => { - for (const item of items) { - if (item.key === key && item.path) return item.path - if (item.children) { - const childPath = findPath(item.children, key) - if (childPath) return childPath - } - } - return null - } - return findPath(userMenuItems, menuKey) || '/dashboard' -} - -export function getMenuBreadcrumb(menuKey: string): string[] { - const findBreadcrumb = (items: MenuItem[], key: string, path: string[] = []): string[] | null => { - for (const item of 
items) { - const currentPath = [...path, item.label] - if (item.key === menuKey) return currentPath - if (item.children) { - const childPath = findBreadcrumb(item.children, key, currentPath) - if (childPath) return childPath - } - } - return null - } - return findBreadcrumb(userMenuItems, menuKey) || ['仪表盘'] -} - -export const generateComponentMap = () => { - const map: Record = {} - const processConfigs = (configs: UserMenuConfig[]) => { - configs.forEach(config => { - if (explicitComponentMap[config.component]) { - map[config.key] = explicitComponentMap[config.component] - } else { - map[config.key] = defineAsyncComponent(() => import('@/views/NotFound.vue')) - } - if (config.children) processConfigs(config.children) - }) - } - processConfigs(effectiveUserMenuConfigs) - return map -} - -export const userComponentMap = generateComponentMap() - -export const getFilteredUserMenuItems = (userRoles: string[], userPermissions: string[]): MenuItem[] => { - return effectiveUserMenuConfigs - .filter(config => { - // 隐藏菜单中不显示的项(如个人信息,只在用户下拉菜单中显示) - if (config.meta?.hideInMenu) return false - if (!config.meta?.roles || config.meta.roles.length === 0) return true - return config.meta.roles.some(requiredRole => userRoles.includes(requiredRole)) - }) - .map(config => ({ - key: config.key, - label: config.label, - icon: config.icon, - path: config.path, - disabled: config.disabled, - children: config.children?.filter(child => { - if (!child.meta?.roles || child.meta.roles.length === 0) return true - return child.meta.roles.some(requiredRole => userRoles.includes(requiredRole)) - }).map(child => ({ key: child.key, label: child.label, icon: child.icon, path: child.path, disabled: child.disabled })) - })) -} - -export const hasUserMenuPermission = (menuKey: string, userRoles: string[]): boolean => { - const menuConfig = userMenuConfigs.find(config => config.key === menuKey) - if (!menuConfig) return false - if (!menuConfig.meta?.roles || menuConfig.meta.roles.length === 0) return 
true - return menuConfig.meta.roles.some(requiredRole => userRoles.includes(requiredRole)) -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/stores/hertz_app.ts b/hertz_server_django_ui/src/stores/hertz_app.ts deleted file mode 100644 index 613dbb5..0000000 --- a/hertz_server_django_ui/src/stores/hertz_app.ts +++ /dev/null @@ -1,98 +0,0 @@ -import { defineStore } from 'pinia' -import { ref, computed } from 'vue' -import { i18n } from '@/locales' - -// 主题类型 -export type Theme = 'light' | 'dark' | 'auto' - -// 语言类型 -export type Language = 'zh-CN' | 'en-US' - -export const useAppStore = defineStore('app', () => { - // 状态 - const theme = ref('light') - const language = ref('zh-CN') - const collapsed = ref(false) - const loading = ref(false) - - // 计算属性 - const isDark = computed(() => { - if (theme.value === 'auto') { - return window.matchMedia('(prefers-color-scheme: dark)').matches - } - return theme.value === 'dark' - }) - - const currentLanguage = computed(() => language.value) - - // 方法 - const setTheme = (newTheme: Theme) => { - theme.value = newTheme - localStorage.setItem('theme', newTheme) - - // 应用主题到HTML - const html = document.documentElement - if (newTheme === 'dark' || (newTheme === 'auto' && isDark.value)) { - html.classList.add('dark') - } else { - html.classList.remove('dark') - } - } - - const setLanguage = (newLanguage: Language) => { - language.value = newLanguage - localStorage.setItem('language', newLanguage) - - // 设置i18n语言 - i18n.global.locale.value = newLanguage - } - - const toggleCollapsed = () => { - collapsed.value = !collapsed.value - } - - const setLoading = (state: boolean) => { - loading.value = state - } - - const initAppSettings = () => { - // 从本地存储恢复设置 - const savedTheme = localStorage.getItem('theme') as Theme - const savedLanguage = localStorage.getItem('language') as Language - - if (savedTheme) { - setTheme(savedTheme) - } - - if (savedLanguage) { - setLanguage(savedLanguage) - } else { - // 根据浏览器语言自动设置 - 
const browserLang = navigator.language - if (browserLang.startsWith('zh')) { - setLanguage('zh-CN') - } else { - setLanguage('en-US') - } - } - } - - return { - // 状态 - theme, - language, - collapsed, - loading, - - // 计算属性 - isDark, - currentLanguage, - - // 方法 - setTheme, - setLanguage, - toggleCollapsed, - setLoading, - initAppSettings, - } -}) diff --git a/hertz_server_django_ui/src/stores/hertz_theme.ts b/hertz_server_django_ui/src/stores/hertz_theme.ts deleted file mode 100644 index 8a3c412..0000000 --- a/hertz_server_django_ui/src/stores/hertz_theme.ts +++ /dev/null @@ -1,101 +0,0 @@ -import { defineStore } from 'pinia' -import { ref, watch } from 'vue' - -// 主题配置接口 -export interface ThemeConfig { - // 导航栏 - headerBg: string - headerText: string - headerBorder: string - - // 背景 - pageBg: string - contentBg: string - - // 组件背景 - cardBg: string - cardBorder: string - - // 主色调 - primaryColor: string - textPrimary: string - textSecondary: string -} - -// 默认主题 -const defaultTheme: ThemeConfig = { - headerBg: '#ffffff', - headerText: '#111827', - headerBorder: '#e5e7eb', - pageBg: '#ffffff', - contentBg: '#ffffff', - cardBg: '#ffffff', - cardBorder: '#e5e7eb', - primaryColor: '#2563eb', - textPrimary: '#111827', - textSecondary: '#6b7280', -} - -export const useThemeStore = defineStore('theme', () => { - const theme = ref({ ...defaultTheme }) - - // 从 localStorage 加载主题 - const loadTheme = () => { - const savedTheme = localStorage.getItem('customTheme') - if (savedTheme) { - try { - theme.value = { ...defaultTheme, ...JSON.parse(savedTheme) } - applyTheme(theme.value) - } catch (e) { - console.error('Failed to load theme:', e) - } - } else { - applyTheme(theme.value) - } - } - - // 应用主题 - const applyTheme = (config: ThemeConfig) => { - const root = document.documentElement - - // 设置 CSS 变量 - root.style.setProperty('--theme-header-bg', config.headerBg) - root.style.setProperty('--theme-header-text', config.headerText) - 
root.style.setProperty('--theme-header-border', config.headerBorder) - root.style.setProperty('--theme-page-bg', config.pageBg) - root.style.setProperty('--theme-content-bg', config.contentBg) - root.style.setProperty('--theme-card-bg', config.cardBg) - root.style.setProperty('--theme-card-border', config.cardBorder) - root.style.setProperty('--theme-primary', config.primaryColor) - root.style.setProperty('--theme-text-primary', config.textPrimary) - root.style.setProperty('--theme-text-secondary', config.textSecondary) - } - - // 更新主题 - const updateTheme = (newTheme: Partial) => { - theme.value = { ...theme.value, ...newTheme } - applyTheme(theme.value) - localStorage.setItem('customTheme', JSON.stringify(theme.value)) - } - - // 重置主题 - const resetTheme = () => { - theme.value = { ...defaultTheme } - applyTheme(theme.value) - localStorage.removeItem('customTheme') - } - - // 监听主题变化,自动应用 - watch(theme, (newTheme) => { - applyTheme(newTheme) - }, { deep: true }) - - return { - theme, - loadTheme, - updateTheme, - resetTheme, - applyTheme, - } -}) - diff --git a/hertz_server_django_ui/src/stores/hertz_user.ts b/hertz_server_django_ui/src/stores/hertz_user.ts deleted file mode 100644 index 1cc8d41..0000000 --- a/hertz_server_django_ui/src/stores/hertz_user.ts +++ /dev/null @@ -1,261 +0,0 @@ -import { defineStore } from 'pinia' -import { ref, computed } from 'vue' -import { request } from '@/utils/hertz_request' -import { changePassword } from '@/api/password' -import type { ChangePasswordParams } from '@/api/password' -import { roleApi } from '@/api/role' -import { initializeMenuMapping } from '@/utils/menu_mapping' -import { logoutUser } from '@/api/auth' -import { hasModuleSelection } from '@/config/hertz_modules' - -// 用户信息接口 -interface UserInfo { - user_id: number - username: string - email: string - phone?: string - real_name?: string - avatar?: string - roles: Array<{ - role_id: number - role_name: string - role_code: string - }> - permissions: string[] - 
menu_permissions?: number[] // 用户拥有的菜单权限ID列表 -} - -// 登录参数接口 -interface LoginParams { - username: string - password: string - remember?: boolean -} - -export const useUserStore = defineStore('user', () => { - // 状态 - const userInfo = ref(null) - const token = ref('') - const isLoggedIn = ref(false) - const loading = ref(false) - const userMenuPermissions = ref([]) // 用户菜单权限ID列表 - - // 计算属性 - const hasPermission = computed(() => (permission: string) => { - return userInfo.value?.permissions?.includes(permission) || false - }) - - const isAdmin = computed(() => { - const userRole = userInfo.value?.roles?.[0]?.role_code - return userRole === 'admin' || userRole === 'system_admin' || userRole === 'super_admin' - }) - - // 方法 - const login = async (params: LoginParams) => { - loading.value = true - try { - const response = await request.post<{ - access_token: string - refresh_token: string - user_info: UserInfo - }>('/api/auth/login/', params) - - token.value = response.access_token - userInfo.value = response.user_info - isLoggedIn.value = true - - // 保存到本地存储 - localStorage.setItem('token', response.access_token) - if (params.remember) { - localStorage.setItem('userInfo', JSON.stringify(response.user_info)) - } - - // 获取用户菜单权限(模板模式首次运行时跳过) - const isTemplateMode = import.meta.env.VITE_TEMPLATE_SETUP_MODE === 'true' - if (!isTemplateMode || hasModuleSelection()) { - await fetchUserMenuPermissions() - } - - return response - } catch (error) { - console.error('登录失败:', error) - throw error - } finally { - loading.value = false - } - } - - const logout = async () => { - loading.value = true - try { - // 调用封装好的退出登录接口 - await logoutUser() - - // 清除状态 - token.value = '' - userInfo.value = null - isLoggedIn.value = false - - // 清除本地存储 - localStorage.removeItem('token') - localStorage.removeItem('userInfo') - - } catch (error) { - console.error('退出登录失败:', error) - // 即使请求失败也要清除本地状态 - token.value = '' - userInfo.value = null - isLoggedIn.value = false - 
localStorage.removeItem('token') - localStorage.removeItem('userInfo') - } finally { - loading.value = false - } - } - - const updateUserInfo = async (info: Partial) => { - try { - const response = await request.put('/user/profile', info) - - userInfo.value = { ...userInfo.value, ...response } - localStorage.setItem('userInfo', JSON.stringify(userInfo.value)) - - return response - } catch (error) { - console.error('更新用户信息失败:', error) - throw error - } - } - - const checkAuth = async () => { - console.log('🔍 检查用户认证状态...') - - const savedToken = localStorage.getItem('token') - const savedUserInfo = localStorage.getItem('userInfo') - - console.log('💾 localStorage中的token:', savedToken ? '存在' : '不存在') - console.log('💾 localStorage中的userInfo:', savedUserInfo ? '存在' : '不存在') - - if (savedToken && savedUserInfo) { - try { - const parsedUserInfo = JSON.parse(savedUserInfo) - token.value = savedToken - userInfo.value = parsedUserInfo - isLoggedIn.value = true - - console.log('✅ 用户状态恢复成功') - console.log('👤 恢复的用户信息:', parsedUserInfo) - console.log('🔐 登录状态:', isLoggedIn.value) - - // 获取用户菜单权限(模板模式首次运行时跳过) - const isTemplateMode = import.meta.env.VITE_TEMPLATE_SETUP_MODE === 'true' - if (!isTemplateMode || hasModuleSelection()) { - await fetchUserMenuPermissions() - } - } catch (error) { - console.error('❌ 解析用户信息失败:', error) - clearAuth() - } - } else { - console.log('❌ 没有找到保存的认证信息') - } - } - - const clearAuth = () => { - token.value = '' - userInfo.value = null - isLoggedIn.value = false - userMenuPermissions.value = [] - localStorage.removeItem('token') - localStorage.removeItem('userInfo') - } - - const updatePassword = async (params: ChangePasswordParams) => { - try { - await changePassword(params) - return true - } catch (error) { - console.error('修改密码失败:', error) - throw error - } - } - - // 获取用户菜单权限 - const fetchUserMenuPermissions = async () => { - if (!userInfo.value?.roles?.length) { - userMenuPermissions.value = [] - return [] - } - - try { - const adminRoleCodes = 
['admin', 'system_admin', 'super_admin'] - const hasAdminRole = userInfo.value.roles.some(role => adminRoleCodes.includes(role.role_code)) - - if (!hasAdminRole) { - userMenuPermissions.value = [] - return [] - } - - // 获取用户所有角色的菜单权限 - const allMenuPermissions = new Set() - - for (const role of userInfo.value.roles) { - try { - const response = await roleApi.getRolePermissions(role.role_id) - if (response.success) { - const menuIds = response.data.list || response.data - if (Array.isArray(menuIds)) { - menuIds.forEach((menuId: any) => { - const id = typeof menuId === 'number' ? menuId : Number(menuId) - if (!isNaN(id)) { - allMenuPermissions.add(id) - } - }) - } - } - } catch (error) { - console.error(`获取角色 ${role.role_name} 的菜单权限失败:`, error) - } - } - - const permissions = Array.from(allMenuPermissions) - userMenuPermissions.value = permissions - - // 同时更新用户信息中的菜单权限 - if (userInfo.value) { - userInfo.value.menu_permissions = permissions - } - - // 初始化菜单映射关系 - await initializeMenuMapping() - - return permissions - } catch (error) { - console.error('获取用户菜单权限失败:', error) - userMenuPermissions.value = [] - return [] - } - } - - return { - // 状态 - userInfo, - token, - isLoggedIn, - loading, - userMenuPermissions, - - // 计算属性 - hasPermission, - isAdmin, - - // 方法 - login, - logout, - updateUserInfo, - checkAuth, - clearAuth, - updatePassword, - fetchUserMenuPermissions, - } -}) diff --git a/hertz_server_django_ui/src/styles/index.scss b/hertz_server_django_ui/src/styles/index.scss deleted file mode 100644 index 3f9f844..0000000 --- a/hertz_server_django_ui/src/styles/index.scss +++ /dev/null @@ -1,422 +0,0 @@ -// 全局样式入口文件 -@use 'variables' as *; -@use 'sass:color'; - -// 全局样式 -* { - margin: 0; - padding: 0; - box-sizing: border-box; -} - -html, body { - height: 100%; - font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, 'Noto Sans', sans-serif; - -webkit-font-smoothing: antialiased; - -moz-osx-font-smoothing: grayscale; - 
background: #ffffff; - color: #111827; -} - -#app { - height: 100%; -} - -// 按钮样式 -.btn { - @include transition(all); - padding: $spacing-3 $spacing-6; - border: 1px solid $gray-300; - border-radius: $radius-md; - background: $bg-primary; - cursor: pointer; - font-weight: 500; - font-size: $font-size-sm; - line-height: 1.5; - - &:hover { - border-color: $primary-color; - } - - &.btn-primary { - @include button-style($primary-color, white); - } - - &.btn-secondary { - background: $bg-primary; - color: $gray-700; - border-color: $gray-300; - - &:hover { - background: $gray-50; - border-color: $primary-color; - color: $primary-color; - } - } - - &.btn-success { - @include button-style($success-color, white); - } - - &.btn-danger { - @include button-style($error-color, white); - } - - &.btn-warning { - @include button-style($warning-color, white); - } -} - -// 卡片样式 -.card { - @include card-style; - padding: $spacing-6; - margin-bottom: $spacing-4; -} - -// 表单样式 -.form-item { - margin-bottom: $spacing-4; - - label { - display: block; - margin-bottom: $spacing-2; - font-weight: 500; - color: $gray-700; - font-size: $font-size-sm; - } - - input, select, textarea { - width: 100%; - padding: $spacing-3; - border: 1px solid $gray-300; - border-radius: $radius-md; - font-size: $font-size-sm; - transition: border-color 0.2s; - - &:focus { - outline: none; - border-color: $primary-color; - box-shadow: 0 0 0 3px rgba(37, 99, 235, 0.1); - } - - &:hover { - border-color: $gray-400; - } - } -} - -// 布局辅助类 -.flex-center { - @include flex-center; -} - -.text-ellipsis { - @include text-ellipsis; -} - -// 间距辅助类 -.m-0 { margin: $spacing-0; } -.m-1 { margin: $spacing-1; } -.m-2 { margin: $spacing-2; } -.m-3 { margin: $spacing-3; } -.m-4 { margin: $spacing-4; } -.m-5 { margin: $spacing-5; } -.m-6 { margin: $spacing-6; } -.m-8 { margin: $spacing-8; } - -.p-0 { padding: $spacing-0; } -.p-1 { padding: $spacing-1; } -.p-2 { padding: $spacing-2; } -.p-3 { padding: $spacing-3; } -.p-4 { 
padding: $spacing-4; } -.p-5 { padding: $spacing-5; } -.p-6 { padding: $spacing-6; } -.p-8 { padding: $spacing-8; } - -// ==================== 全局弹窗美化样式 - 苹果风格 ==================== -// 弹窗遮罩层 -.ant-modal-mask { - background: rgba(0, 0, 0, 0.4); - backdrop-filter: blur(8px); - -webkit-backdrop-filter: blur(8px); - transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94); -} - -// 弹窗容器 -.ant-modal-wrap { - display: flex; - align-items: center; - justify-content: center; - padding: 20px; -} - -// 统一按钮主题 - 苹果风格 -.ant-btn { - border-radius: 12px; - font-weight: 500; - padding: 0 20px; - height: 40px; - transition: all 0.25s cubic-bezier(0.25, 0.46, 0.45, 0.94); - border-width: 0.5px; - - &.ant-btn-default { - background: rgba(255, 255, 255, 0.8); - border-color: rgba(0, 0, 0, 0.12); - color: #1d1d1f; - - &:hover { - background: rgba(0, 0, 0, 0.04); - border-color: rgba(0, 0, 0, 0.16); - transform: translateY(-1px); - } - } - - &.ant-btn-primary { - background: #3b82f6; - border-color: #3b82f6; - box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3); - - &:hover { - background: #2563eb; - border-color: #2563eb; - transform: translateY(-1px); - box-shadow: 0 4px 12px rgba(59, 130, 246, 0.4); - } - - &:active { transform: translateY(0); } - } - - &.ant-btn-dangerous:not(.ant-btn-link) { - color: #ef4444; - border-color: rgba(239, 68, 68, 0.3); - - &:hover { - background: rgba(239, 68, 68, 0.1); - border-color: #ef4444; - color: #ef4444; - } - } - - &.ant-btn-link { - border-radius: 8px; - } - - &.ant-btn-sm { - border-radius: 8px; - height: 30px; - padding: 0 14px; - } -} - -// 弹窗内容 - 苹果风格 -.ant-modal { - top: 0; - padding-bottom: 0; - - .ant-modal-content { - background: rgba(255, 255, 255, 0.95); - backdrop-filter: saturate(180%) blur(20px); - -webkit-backdrop-filter: saturate(180%) blur(20px); - border-radius: 20px; - box-shadow: - 0 20px 60px rgba(0, 0, 0, 0.2), - 0 0 0 0.5px rgba(0, 0, 0, 0.08); - overflow: hidden; - animation: modalSlideIn 0.4s cubic-bezier(0.34, 1.56, 0.64, 
1); - transition: all 0.3s cubic-bezier(0.25, 0.46, 0.45, 0.94); - } - - // 弹窗头部 - .ant-modal-header { - background: rgba(255, 255, 255, 0.8); - backdrop-filter: saturate(180%) blur(20px); - -webkit-backdrop-filter: saturate(180%) blur(20px); - border-bottom: 0.5px solid rgba(0, 0, 0, 0.08); - padding: 24px 28px; - border-radius: 20px 20px 0 0; - - .ant-modal-title { - font-weight: 600; - color: #1d1d1f; - font-size: 20px; - letter-spacing: -0.3px; - line-height: 1.3; - } - - .ant-modal-close { - top: 24px; - right: 28px; - width: 32px; - height: 32px; - border-radius: 8px; - transition: all 0.25s cubic-bezier(0.25, 0.46, 0.45, 0.94); - - &:hover { - background: rgba(0, 0, 0, 0.06); - transform: scale(1.1); - } - - .ant-modal-close-x { - width: 32px; - height: 32px; - line-height: 32px; - font-size: 16px; - color: #86868b; - transition: color 0.2s ease; - - &:hover { color: #1d1d1f; } - } - } - } - - // 弹窗主体 - .ant-modal-body { - padding: 28px; - background: rgba(255, 255, 255, 0.95); - color: #1d1d1f; - line-height: 1.6; - } - - // 弹窗底部 - .ant-modal-footer { - background: rgba(255, 255, 255, 0.8); - backdrop-filter: saturate(180%) blur(20px); - -webkit-backdrop-filter: saturate(180%) blur(20px); - border-top: 0.5px solid rgba(0, 0, 0, 0.08); - padding: 20px 28px; - border-radius: 0 0 20px 20px; - - .ant-btn { - border-radius: 10px; - height: 40px; - padding: 0 20px; - font-weight: 500; - font-size: 14px; - transition: all 0.25s cubic-bezier(0.25, 0.46, 0.45, 0.94); - border: 0.5px solid rgba(0, 0, 0, 0.12); - - &:not(.ant-btn-primary) { - background: rgba(255, 255, 255, 0.8); - color: #1d1d1f; - - &:hover { - background: rgba(0, 0, 0, 0.04); - border-color: rgba(0, 0, 0, 0.16); - transform: translateY(-1px); - box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1); - } - } - - &.ant-btn-primary { - background: #3b82f6; - border-color: #3b82f6; - box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3); - - &:hover { - background: #2563eb; - border-color: #2563eb; - transform: 
translateY(-1px); - box-shadow: 0 4px 12px rgba(59, 130, 246, 0.4); - } - - &:active { transform: translateY(0); } - } - - &.ant-btn-dangerous { - color: #ef4444; - border-color: rgba(239, 68, 68, 0.3); - - &:hover { - background: rgba(239, 68, 68, 0.1); - border-color: #ef4444; - color: #ef4444; - } - } - } - } - - // 表单元素美化 - .ant-form-item-label > label { - font-weight: 500; - color: #1d1d1f; - font-size: 14px; - letter-spacing: -0.1px; - } - - .ant-input, - .ant-select-selector, - .ant-input-number, - .ant-picker, - .ant-textarea, - .ant-tree-select-selector { - border-radius: 10px; - border: 0.5px solid rgba(0, 0, 0, 0.12); - background: rgba(255, 255, 255, 0.9); - transition: all 0.25s cubic-bezier(0.25, 0.46, 0.45, 0.94); - font-size: 14px; - - &:hover { border-color: #3b82f6; background: rgba(255, 255, 255, 1); } - &:focus, - &.ant-input-focused, - &.ant-select-focused .ant-select-selector, - &.ant-picker-focused { border-color: #3b82f6; box-shadow: 0 0 0 3px rgba(59, 130, 246, 0.1); background: rgba(255, 255, 255, 1); } - } - - .ant-input-number { width: 100%; } - - .ant-radio-group { - .ant-radio-button-wrapper { - border-radius: 8px; - border: 0.5px solid rgba(0, 0, 0, 0.12); - transition: all 0.25s cubic-bezier(0.25, 0.46, 0.45, 0.94); - &:hover { border-color: #3b82f6; } - &.ant-radio-button-wrapper-checked { background: #3b82f6; border-color: #3b82f6; box-shadow: 0 2px 8px rgba(59, 130, 246, 0.3); } - } - } - - .ant-switch { background: rgba(0, 0, 0, 0.25); &.ant-switch-checked { background: #10b981; } } - - // 表格在弹窗中的样式 - .ant-table { - background: transparent; - .ant-table-thead > tr > th { - background: rgba(0, 0, 0, 0.02); - font-weight: 600; - color: #1d1d1f; - border-bottom: 0.5px solid rgba(0, 0, 0, 0.08); - padding: 16px; - font-size: 13px; - } - .ant-table-tbody > tr { - transition: all 0.2s ease; - &:hover > td { background: rgba(0, 0, 0, 0.02); } - > td { padding: 16px; border-bottom: 0.5px solid rgba(0, 0, 0, 0.06); color: #1d1d1f; } - } - 
} - - // 标签样式 - .ant-tag { border-radius: 6px; font-weight: 500; padding: 2px 10px; border: 0.5px solid currentColor; opacity: 0.8; } - - // 描述列表样式 - .ant-descriptions { - .ant-descriptions-item-label { font-weight: 500; color: #1d1d1f; background: rgba(0, 0, 0, 0.02); } - .ant-descriptions-item-content { color: #86868b; } - } -} - -// 弹窗动画 -@keyframes modalSlideIn { - from { opacity: 0; transform: scale(0.95) translateY(-20px); } - to { opacity: 1; transform: scale(1) translateY(0); } -} - -// 响应式优化 -@media (max-width: 768px) { - .ant-modal { - .ant-modal-content { border-radius: 16px; } - .ant-modal-header { padding: 20px 20px; border-radius: 16px 16px 0 0; .ant-modal-title { font-size: 18px; } } - .ant-modal-body { padding: 20px; } - .ant-modal-footer { padding: 16px 20px; border-radius: 0 0 16px 16px; } - } -} diff --git a/hertz_server_django_ui/src/styles/variables.scss b/hertz_server_django_ui/src/styles/variables.scss deleted file mode 100644 index c596a87..0000000 --- a/hertz_server_django_ui/src/styles/variables.scss +++ /dev/null @@ -1,124 +0,0 @@ -// 全局变量文件 - 简约现代风格 - -// 颜色系统 -$primary-color: #2563eb; -$primary-light: #3b82f6; -$primary-dark: #1d4ed8; -$success-color: #10b981; -$warning-color: #f59e0b; -$error-color: #ef4444; -$info-color: #06b6d4; - -// 中性色系统 -$gray-50: #f9fafb; -$gray-100: #f3f4f6; -$gray-200: #e5e7eb; -$gray-300: #d1d5db; -$gray-400: #9ca3af; -$gray-500: #6b7280; -$gray-600: #4b5563; -$gray-700: #374151; -$gray-800: #1f2937; -$gray-900: #111827; - -// 背景色 -$bg-primary: #ffffff; -$bg-secondary: #f9fafb; -$bg-tertiary: #f3f4f6; - -// 字体大小 -$font-size-xs: 12px; -$font-size-sm: 14px; -$font-size-base: 16px; -$font-size-lg: 18px; -$font-size-xl: 20px; -$font-size-2xl: 24px; -$font-size-3xl: 30px; -$font-size-4xl: 36px; - -// 间距系统 - 4px基础单位 -$spacing-0: 0; -$spacing-1: 4px; -$spacing-2: 8px; -$spacing-3: 12px; -$spacing-4: 16px; -$spacing-5: 20px; -$spacing-6: 24px; -$spacing-8: 32px; -$spacing-10: 40px; -$spacing-12: 48px; -$spacing-16: 
64px; -$spacing-20: 80px; - -// 圆角系统 -$radius-none: 0; -$radius-sm: 4px; -$radius-md: 6px; -$radius-lg: 8px; -$radius-xl: 12px; -$radius-2xl: 16px; -$radius-full: 9999px; - -// 阴影系统 -$shadow-sm: 0 1px 2px 0 rgba(0, 0, 0, 0.05); -$shadow-md: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06); -$shadow-lg: 0 10px 15px -3px rgba(0, 0, 0, 0.1), 0 4px 6px -2px rgba(0, 0, 0, 0.05); -$shadow-xl: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04); - -// 过渡时间 -$transition-fast: 0.15s; -$transition-normal: 0.2s; -$transition-slow: 0.3s; - -// 混合器 -@mixin flex-center { - display: flex; - justify-content: center; - align-items: center; -} - -@mixin text-ellipsis { - overflow: hidden; - text-overflow: ellipsis; - white-space: nowrap; -} - -@mixin box-shadow($shadow: $shadow-md) { - box-shadow: $shadow; -} - -@mixin transition($property: all, $duration: $transition-normal) { - transition: #{$property} #{$duration} ease; -} - -@mixin card-style { - background: $bg-primary; - border-radius: $radius-lg; - box-shadow: $shadow-sm; - border: 1px solid $gray-200; - transition: all 0.2s ease; - - &:hover { - box-shadow: $shadow-md; - } -} - -@mixin button-style($bg-color: $primary-color, $text-color: white) { - background: $bg-color; - color: $text-color; - border: 1px solid $bg-color; - border-radius: $radius-md; - padding: $spacing-3 $spacing-6; - font-weight: 500; - cursor: pointer; - transition: background 0.2s ease; - - &:hover { - background: darken($bg-color, 8%); - border-color: darken($bg-color, 8%); - } - - &:active { - background: darken($bg-color, 12%); - } -} diff --git a/hertz_server_django_ui/src/types/env.d.ts b/hertz_server_django_ui/src/types/env.d.ts deleted file mode 100644 index bc33386..0000000 --- a/hertz_server_django_ui/src/types/env.d.ts +++ /dev/null @@ -1,13 +0,0 @@ -/// - -interface ImportMetaEnv { - readonly VITE_API_BASE_URL: string - readonly VITE_APP_TITLE: string - readonly VITE_APP_VERSION: string - readonly 
VITE_DEV_SERVER_HOST: string - readonly VITE_DEV_SERVER_PORT: string -} - -interface ImportMeta { - readonly env: ImportMetaEnv -} diff --git a/hertz_server_django_ui/src/types/hertz_types.ts b/hertz_server_django_ui/src/types/hertz_types.ts deleted file mode 100644 index 2efd64d..0000000 --- a/hertz_server_django_ui/src/types/hertz_types.ts +++ /dev/null @@ -1,182 +0,0 @@ -// 通用响应类型 -export interface ApiResponse { - code: number - message: string - data: T - success?: boolean - timestamp?: string -} - -// 分页请求参数 -export interface PageParams { - page: number - pageSize: number - sortBy?: string - sortOrder?: 'asc' | 'desc' -} - -// 分页响应 -export interface PageResponse { - list: T[] - total: number - page: number - pageSize: number - totalPages: number -} - -// 用户相关类型 -export interface User { - id: number - username: string - email: string - avatar?: string - role: string - permissions: string[] - status: 'active' | 'inactive' | 'banned' - createTime: string - updateTime: string -} - -export interface LoginParams { - username: string - password: string - remember?: boolean -} - -export interface LoginResponse { - token: string - user: User - expiresIn: number -} - -// 菜单相关类型 -export interface MenuItem { - id: number - name: string - path: string - icon?: string - children?: MenuItem[] - permission?: string - hidden?: boolean - meta?: { - title: string - requiresAuth?: boolean - } -} - -// 表格相关类型 -export interface TableColumn { - key: string - title: string - width?: number - fixed?: 'left' | 'right' - sortable?: boolean - render?: (record: T, index: number) => any -} - -export interface TableProps { - data: T[] - columns: TableColumn[] - loading?: boolean - pagination?: { - current: number - pageSize: number - total: number - showSizeChanger?: boolean - showQuickJumper?: boolean - } - rowSelection?: { - selectedRowKeys: (string | number)[] - onChange: (selectedRowKeys: (string | number)[], selectedRows: T[]) => void - } -} - -// 表单相关类型 -export interface FormField { - 
name: string - label: string - type: 'input' | 'select' | 'textarea' | 'date' | 'switch' | 'radio' | 'checkbox' - required?: boolean - placeholder?: string - options?: { label: string; value: any }[] - rules?: any[] -} - -export interface FormProps { - fields: FormField[] - initialValues?: Record - onSubmit: (values: Record) => Promise - loading?: boolean -} - -// 弹窗相关类型 -export interface ModalProps { - title: string - visible: boolean - onCancel: () => void - onOk?: () => void - width?: number - children: any -} - -// 消息相关类型 -export type MessageType = 'success' | 'error' | 'warning' | 'info' - -export interface MessageConfig { - type: MessageType - content: string - duration?: number -} - -// 主题相关类型 -export type Theme = 'light' | 'dark' | 'auto' - -// 语言相关类型 -export type Language = 'zh-CN' | 'en-US' - -// 路由相关类型 -export interface RouteMeta { - title?: string - requiresAuth?: boolean - permission?: string - hidden?: boolean - icon?: string -} - -// 组件属性类型 -export interface ComponentProps { - className?: string - style?: Record - children?: any -} - -// 工具函数类型 -export type DeepPartial = { - [P in keyof T]?: T[P] extends object ? 
DeepPartial : T[P] -} - -export type Optional = Omit & Partial> - -// API 相关类型 -export interface RequestConfig { - showLoading?: boolean - showError?: boolean - timeout?: number -} - -// 文件相关类型 -export interface FileInfo { - name: string - size: number - type: string - url?: string - lastModified: number -} - -export interface UploadProps { - accept?: string - multiple?: boolean - maxSize?: number - onUpload: (files: File[]) => Promise - onRemove?: (file: FileInfo) => void -} diff --git a/hertz_server_django_ui/src/utils/hertz_captcha.ts b/hertz_server_django_ui/src/utils/hertz_captcha.ts deleted file mode 100644 index eca9017..0000000 --- a/hertz_server_django_ui/src/utils/hertz_captcha.ts +++ /dev/null @@ -1,70 +0,0 @@ -import { generateCaptcha, refreshCaptcha, type CaptchaResponse, type CaptchaRefreshResponse } from '@/api/captcha' -import { ref, type Ref } from 'vue' - -/** - * 验证码组合式函数 - */ -export function useCaptcha() { - // 验证码数据 - const captchaData: Ref = ref(null) - - // 加载状态 - const captchaLoading: Ref = ref(false) - - // 错误信息 - const captchaError: Ref = ref(null) - - /** - * 生成验证码 - */ - const handleGenerateCaptcha = async (): Promise => { - try { - captchaLoading.value = true - captchaError.value = null - - const response = await generateCaptcha() - captchaData.value = response - } catch (error) { - console.error('生成验证码失败:', error) - captchaError.value = error instanceof Error ? error.message : '生成验证码失败' - } finally { - captchaLoading.value = false - } - } - - /** - * 刷新验证码 - */ - const handleRefreshCaptcha = async (): Promise => { - try { - captchaLoading.value = true - captchaError.value = null - - // 检查是否有当前验证码ID - if (!captchaData.value?.captcha_id) { - console.warn('没有当前验证码ID,将生成新的验证码') - await handleGenerateCaptcha() - return - } - - const response = await refreshCaptcha(captchaData.value.captcha_id) - captchaData.value = response - } catch (error) { - console.error('刷新验证码失败:', error) - captchaError.value = error instanceof Error ? 
error.message : '刷新验证码失败' - } finally { - captchaLoading.value = false - } - } - - return { - captchaData, - captchaLoading, - captchaError, - generateCaptcha: handleGenerateCaptcha, - refreshCaptcha: handleRefreshCaptcha - } -} - -// 导出类型 -export type { CaptchaResponse, CaptchaRefreshResponse } \ No newline at end of file diff --git a/hertz_server_django_ui/src/utils/hertz_env.ts b/hertz_server_django_ui/src/utils/hertz_env.ts deleted file mode 100644 index 62eb07e..0000000 --- a/hertz_server_django_ui/src/utils/hertz_env.ts +++ /dev/null @@ -1,87 +0,0 @@ -/** - * 环境变量检查工具 - * 用于在开发环境中检查环境变量配置是否正确 - */ - -// 检查环境变量配置 -export const checkEnvironmentVariables = () => { - console.log('🔧 环境变量检查') - - // 在Vite中,环境变量可能通过define选项直接定义 - // 或者通过import.meta.env读取 - const apiBaseUrl = import.meta.env.VITE_API_BASE_URL || 'http://localhost:3000' - const appTitle = import.meta.env.VITE_APP_TITLE || 'Hertz Admin' - const appVersion = import.meta.env.VITE_APP_VERSION || '1.0.0' - - // 检查必需的环境变量 - const requiredVars = [ - { key: 'VITE_API_BASE_URL', value: apiBaseUrl }, - { key: 'VITE_APP_TITLE', value: appTitle }, - { key: 'VITE_APP_VERSION', value: appVersion }, - ] - - requiredVars.forEach(({ key, value }) => { - if (value) { - console.log(`✅ ${key}: ${value}`) - } else { - console.warn(`❌ ${key}: 未设置`) - } - }) - - // 检查可选的环境变量 - const devServerHost = import.meta.env.VITE_DEV_SERVER_HOST || 'localhost' - const devServerPort = import.meta.env.VITE_DEV_SERVER_PORT || '3000' - - const optionalVars = [ - { key: 'VITE_DEV_SERVER_HOST', value: devServerHost }, - { key: 'VITE_DEV_SERVER_PORT', value: devServerPort }, - ] - - optionalVars.forEach(({ key, value }) => { - if (value) { - console.log(`ℹ️ ${key}: ${value}`) - } else { - console.log(`➖ ${key}: 未设置(使用默认值)`) - } - }) - - console.log('🎉 环境变量检查完成') -} - -// 验证环境变量是否有效 -export const validateEnvironment = () => { - // 检查API基础地址 - if (!import.meta.env.VITE_API_BASE_URL) { - console.warn('⚠️ VITE_API_BASE_URL 未设置,将使用默认值') - } - - 
// 检查应用配置 - if (!import.meta.env.VITE_APP_TITLE) { - console.warn('⚠️ VITE_APP_TITLE 未设置,将使用默认值') - } - - if (!import.meta.env.VITE_APP_VERSION) { - console.warn('⚠️ VITE_APP_VERSION 未设置,将使用默认值') - } - - return { - isValid: true, - warnings: [] - } -} - -// 获取API基础地址 -export const getApiBaseUrl = (): string => { - return import.meta.env.VITE_API_BASE_URL || 'http://localhost:3000' -} - -// 获取应用配置 -export const getAppConfig = () => { - return { - title: import.meta.env.VITE_APP_TITLE || 'Hertz Admin', - version: import.meta.env.VITE_APP_VERSION || '1.0.0', - apiBaseUrl: getApiBaseUrl(), - devServerHost: import.meta.env.VITE_DEV_SERVER_HOST || 'localhost', - devServerPort: import.meta.env.VITE_DEV_SERVER_PORT || '3000', - } -} diff --git a/hertz_server_django_ui/src/utils/hertz_error_handler.ts b/hertz_server_django_ui/src/utils/hertz_error_handler.ts deleted file mode 100644 index 5865438..0000000 --- a/hertz_server_django_ui/src/utils/hertz_error_handler.ts +++ /dev/null @@ -1,375 +0,0 @@ -import { message } from 'ant-design-vue' -import { useI18n } from 'vue-i18n' - -// 错误类型枚举 -export enum ErrorType { - // 网络错误 - NETWORK_ERROR = 'NETWORK_ERROR', - TIMEOUT = 'TIMEOUT', - - // 认证错误 - UNAUTHORIZED = 'UNAUTHORIZED', - TOKEN_EXPIRED = 'TOKEN_EXPIRED', - TOKEN_INVALID = 'TOKEN_INVALID', - - // 权限错误 - FORBIDDEN = 'FORBIDDEN', - ACCESS_DENIED = 'ACCESS_DENIED', - - // 业务错误 - VALIDATION_ERROR = 'VALIDATION_ERROR', - BUSINESS_ERROR = 'BUSINESS_ERROR', - - // 系统错误 - SERVER_ERROR = 'SERVER_ERROR', - SERVICE_UNAVAILABLE = 'SERVICE_UNAVAILABLE', -} - -// 错误信息接口 -export interface ErrorInfo { - code: number - message: string - type: ErrorType - details?: any - field?: string -} - -// 错误处理器类 -export class HertzErrorHandler { - private static instance: HertzErrorHandler - private i18n: any - - constructor() { - // 在组件中使用时需要传入i18n实例 - } - - static getInstance(): HertzErrorHandler { - if (!HertzErrorHandler.instance) { - HertzErrorHandler.instance = new HertzErrorHandler() - } - 
return HertzErrorHandler.instance - } - - // 设置i18n实例 - setI18n(i18n: any) { - this.i18n = i18n - } - - // 获取翻译文本 - private t(key: string, fallback?: string): string { - if (this.i18n && this.i18n.t) { - return this.i18n.t(key) - } - return fallback || key - } - - // 处理HTTP错误 - handleHttpError(error: any): void { - const status = error?.response?.status - const data = error?.response?.data - - console.error('🚨 HTTP错误详情:', { - status, - data, - url: error?.config?.url, - method: error?.config?.method, - requestData: error?.config?.data - }) - - switch (status) { - case 400: - this.handleBadRequestError(data) - break - case 401: - this.handleUnauthorizedError(data) - break - case 403: - this.handleForbiddenError(data) - break - case 404: - this.handleNotFoundError(data) - break - case 422: - this.handleValidationError(data) - break - case 429: - this.handleTooManyRequestsError(data) - break - case 500: - this.handleServerError(data) - break - case 502: - case 503: - case 504: - this.handleServiceUnavailableError(data) - break - default: - this.handleUnknownError(error) - } - } - - // 处理400错误 - private handleBadRequestError(data: any): void { - const message = data?.message || data?.detail || '' - - // 检查是否是验证码相关错误 - if (this.isMessageContains(message, ['验证码', 'captcha', 'Captcha'])) { - if (this.isMessageContains(message, ['过期', 'expired', 'expire'])) { - this.showError(this.t('error.captchaExpired', '验证码已过期,请刷新后重新输入')) - } else { - this.showError(this.t('error.captchaError', '验证码错误,请重新输入(区分大小写)')) - } - return - } - - // 检查是否是用户名或密码错误 - if (this.isMessageContains(message, ['用户名', 'username', '密码', 'password', '登录', 'login'])) { - this.showError(this.t('error.loginFailed', '登录失败,请检查用户名和密码')) - return - } - - // 检查是否是注册相关错误 - if (this.isMessageContains(message, ['用户名已存在', 'username exists', 'username already'])) { - this.showError(this.t('error.usernameExists', '用户名已存在,请选择其他用户名')) - return - } - - if (this.isMessageContains(message, ['邮箱已注册', 'email exists', 'email 
already'])) { - this.showError(this.t('error.emailExists', '邮箱已被注册,请使用其他邮箱')) - return - } - - if (this.isMessageContains(message, ['手机号已注册', 'phone exists', 'phone already'])) { - this.showError(this.t('error.phoneExists', '手机号已被注册,请使用其他手机号')) - return - } - - // 默认400错误处理 - this.showError(data?.message || this.t('error.invalidInput', '输入数据格式不正确')) - } - - // 处理401错误 - private handleUnauthorizedError(data: any): void { - const message = data?.message || data?.detail || '' - - if (this.isMessageContains(message, ['token', 'Token', '令牌', '过期', 'expired'])) { - this.showError(this.t('error.tokenExpired', '登录已过期,请重新登录')) - // 可以在这里添加自动跳转到登录页的逻辑 - setTimeout(() => { - window.location.href = '/login' - }, 2000) - } else if (this.isMessageContains(message, ['账户锁定', 'account locked', 'locked'])) { - this.showError(this.t('error.accountLocked', '账户已被锁定,请联系管理员')) - } else if (this.isMessageContains(message, ['账户禁用', 'account disabled', 'disabled'])) { - this.showError(this.t('error.accountDisabled', '账户已被禁用,请联系管理员')) - } else { - this.showError(this.t('error.loginFailed', '登录失败,请检查用户名和密码')) - } - } - - // 处理403错误 - private handleForbiddenError(data: any): void { - const message = data?.message || data?.detail || '' - - if (this.isMessageContains(message, ['权限不足', 'permission denied', 'access denied'])) { - this.showError(this.t('error.permissionDenied', '权限不足,无法执行此操作')) - } else { - this.showError(this.t('error.accessDenied', '访问被拒绝,您没有执行此操作的权限')) - } - } - - // 处理404错误 - private handleNotFoundError(data: any): void { - const message = data?.message || data?.detail || '' - - if (this.isMessageContains(message, ['用户', 'user'])) { - this.showError(this.t('error.userNotFound', '用户不存在或已被删除')) - } else if (this.isMessageContains(message, ['部门', 'department'])) { - this.showError(this.t('error.departmentNotFound', '部门不存在或已被删除')) - } else if (this.isMessageContains(message, ['角色', 'role'])) { - this.showError(this.t('error.roleNotFound', '角色不存在或已被删除')) - } else { - 
this.showError(this.t('error.404', '页面未找到')) - } - } - - // 处理422验证错误 - private handleValidationError(data: any): void { - console.log('🔍 422验证错误详情:', data) - - // 处理FastAPI风格的验证错误 - if (data?.detail && Array.isArray(data.detail)) { - const errors = data.detail - const errorMessages: string[] = [] - - errors.forEach((error: any) => { - const field = error.loc?.[error.loc.length - 1] || 'unknown' - const msg = error.msg || error.message || '验证失败' - - // 根据字段和错误类型提供更具体的提示 - if (field === 'username') { - if (msg.includes('required') || msg.includes('必填')) { - errorMessages.push(this.t('error.usernameRequired', '请输入用户名')) - } else if (msg.includes('length') || msg.includes('长度')) { - errorMessages.push('用户名长度不符合要求') - } else { - errorMessages.push(`用户名: ${msg}`) - } - } else if (field === 'password') { - if (msg.includes('required') || msg.includes('必填')) { - errorMessages.push(this.t('error.passwordRequired', '请输入密码')) - } else if (msg.includes('weak') || msg.includes('强度')) { - errorMessages.push(this.t('error.passwordTooWeak', '密码强度不足,请包含大小写字母、数字和特殊字符')) - } else { - errorMessages.push(`密码: ${msg}`) - } - } else if (field === 'email') { - if (msg.includes('format') || msg.includes('格式')) { - errorMessages.push(this.t('error.emailFormatError', '邮箱格式不正确,请输入有效的邮箱地址')) - } else { - errorMessages.push(`邮箱: ${msg}`) - } - } else if (field === 'phone') { - if (msg.includes('format') || msg.includes('格式')) { - errorMessages.push(this.t('error.phoneFormatError', '手机号格式不正确,请输入11位手机号')) - } else { - errorMessages.push(`手机号: ${msg}`) - } - } else if (field === 'captcha' || field === 'captcha_code') { - errorMessages.push(this.t('error.captchaError', '验证码错误,请重新输入(区分大小写)')) - } else { - errorMessages.push(`${field}: ${msg}`) - } - }) - - if (errorMessages.length > 0) { - this.showError(errorMessages.join(';')) - return - } - } - - // 处理其他格式的验证错误 - if (data?.errors) { - const errors = data.errors - const errorMessages = [] - for (const field in errors) { - if (errors[field] && 
Array.isArray(errors[field])) { - errorMessages.push(`${field}: ${errors[field].join(', ')}`) - } else if (errors[field]) { - errorMessages.push(`${field}: ${errors[field]}`) - } - } - if (errorMessages.length > 0) { - this.showError(`验证失败: ${errorMessages.join('; ')}`) - return - } - } - - // 默认验证错误处理 - this.showError(data?.message || this.t('error.invalidInput', '输入数据格式不正确')) - } - - // 处理429错误(请求过多) - private handleTooManyRequestsError(data: any): void { - this.showError(this.t('error.loginAttemptsExceeded', '登录尝试次数过多,账户已被临时锁定')) - } - - // 处理500错误 - private handleServerError(data: any): void { - this.showError(this.t('error.500', '服务器内部错误,请稍后重试')) - } - - // 处理服务不可用错误 - private handleServiceUnavailableError(data: any): void { - this.showError(this.t('error.serviceUnavailable', '服务暂时不可用,请稍后重试')) - } - - // 处理网络错误 - handleNetworkError(error: any): void { - if (error?.code === 'NETWORK_ERROR' || error?.message?.includes('Network Error')) { - this.showError(this.t('error.networkError', '网络连接失败,请检查网络设置')) - } else if (error?.code === 'ECONNABORTED' || error?.message?.includes('timeout')) { - this.showError(this.t('error.timeout', '请求超时,请稍后重试')) - } else { - this.showError(this.t('error.networkError', '网络连接失败,请检查网络设置')) - } - } - - // 处理未知错误 - private handleUnknownError(error: any): void { - console.error('🚨 未知错误:', error) - this.showError(this.t('error.operationFailed', '操作失败,请稍后重试')) - } - - // 显示错误消息 - private showError(msg: string): void { - message.error(msg) - } - - // 显示成功消息 - showSuccess(msg: string): void { - message.success(msg) - } - - // 显示警告消息 - showWarning(msg: string): void { - message.warning(msg) - } - - // 检查消息是否包含指定关键词 - private isMessageContains(message: string, keywords: string[]): boolean { - if (!message) return false - const lowerMessage = message.toLowerCase() - return keywords.some(keyword => lowerMessage.includes(keyword.toLowerCase())) - } - - // 处理业务操作成功 - handleSuccess(operation: string, customMessage?: string): void { - if 
(customMessage) { - this.showSuccess(customMessage) - return - } - - switch (operation) { - case 'save': - this.showSuccess(this.t('error.saveSuccess', '保存成功')) - break - case 'delete': - this.showSuccess(this.t('error.deleteSuccess', '删除成功')) - break - case 'update': - this.showSuccess(this.t('error.updateSuccess', '更新成功')) - break - case 'create': - this.showSuccess('创建成功') - break - case 'login': - this.showSuccess('登录成功') - break - case 'register': - this.showSuccess('注册成功') - break - default: - this.showSuccess('操作成功') - } - } -} - -// 导出单例实例 -export const errorHandler = HertzErrorHandler.getInstance() - -// 导出便捷方法 -export const handleError = (error: any) => { - if (error?.response) { - errorHandler.handleHttpError(error) - } else if (error?.code === 'NETWORK_ERROR' || error?.message?.includes('Network Error')) { - errorHandler.handleNetworkError(error) - } else { - console.error('🚨 处理错误:', error) - errorHandler.showError('操作失败,请稍后重试') - } -} - -export const handleSuccess = (operation: string, customMessage?: string) => { - errorHandler.handleSuccess(operation, customMessage) -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/utils/hertz_permission.ts b/hertz_server_django_ui/src/utils/hertz_permission.ts deleted file mode 100644 index 95ef13a..0000000 --- a/hertz_server_django_ui/src/utils/hertz_permission.ts +++ /dev/null @@ -1,154 +0,0 @@ -/** - * 权限管理工具类 - * 统一管理用户权限检查和菜单过滤逻辑 - */ - -import { computed } from 'vue' -import { useUserStore } from '@/stores/hertz_user' -import { UserRole } from '@/router/admin_menu' - -// 权限检查接口 -export interface PermissionChecker { - hasRole(role: string): boolean - hasPermission(permission: string): boolean - hasAnyRole(roles: string[]): boolean - hasAnyPermission(permissions: string[]): boolean - isAdmin(): boolean - isLoggedIn(): boolean -} - -// 权限管理类 -export class PermissionManager implements PermissionChecker { - // 延迟获取 Pinia store,避免在 Pinia 未初始化时调用 - private get userStore() { - return 
useUserStore() - } - - /** - * 检查用户是否拥有指定角色 - */ - hasRole(role: string): boolean { - const userRoles = this.userStore.userInfo?.roles?.map(r => r.role_code) || [] - return userRoles.includes(role) - } - - /** - * 检查用户是否拥有指定权限 - */ - hasPermission(permission: string): boolean { - const userPermissions = this.userStore.userInfo?.permissions || [] - return userPermissions.includes(permission) - } - - /** - * 检查用户是否拥有任意一个指定角色 - */ - hasAnyRole(roles: string[]): boolean { - const userRoles = this.userStore.userInfo?.roles?.map(r => r.role_code) || [] - return roles.some(role => userRoles.includes(role)) - } - - /** - * 检查用户是否拥有任意一个指定权限 - */ - hasAnyPermission(permissions: string[]): boolean { - const userPermissions = this.userStore.userInfo?.permissions || [] - return permissions.some(permission => userPermissions.includes(permission)) - } - - /** - * 检查用户是否为管理员 - */ - isAdmin(): boolean { - const adminRoles = [UserRole.ADMIN, UserRole.SYSTEM_ADMIN, UserRole.SUPER_ADMIN] - return this.hasAnyRole(adminRoles) - } - - /** - * 检查用户是否已登录 - */ - isLoggedIn(): boolean { - return this.userStore.isLoggedIn && !!this.userStore.userInfo - } - - /** - * 获取用户角色列表 - */ - getUserRoles(): string[] { - return this.userStore.userInfo?.roles?.map(r => r.role_code) || [] - } - - /** - * 获取用户权限列表 - */ - getUserPermissions(): string[] { - return this.userStore.userInfo?.permissions || [] - } - - /** - * 检查用户是否可以访问指定路径 - */ - canAccessPath(path: string, requiredRoles?: string[], requiredPermissions?: string[]): boolean { - if (!this.isLoggedIn()) { - return false - } - - // 如果没有指定权限要求,默认允许访问 - if (!requiredRoles && !requiredPermissions) { - return true - } - - // 检查角色权限 - if (requiredRoles && requiredRoles.length > 0) { - if (!this.hasAnyRole(requiredRoles)) { - return false - } - } - - // 检查具体权限 - if (requiredPermissions && requiredPermissions.length > 0) { - if (!this.hasAnyPermission(requiredPermissions)) { - return false - } - } - - return true - } -} - -// 创建全局权限管理实例 -export const 
permissionManager = new PermissionManager() - -// 便捷的权限检查函数 -export const usePermission = () => { - return { - hasRole: (role: string) => permissionManager.hasRole(role), - hasPermission: (permission: string) => permissionManager.hasPermission(permission), - hasAnyRole: (roles: string[]) => permissionManager.hasAnyRole(roles), - hasAnyPermission: (permissions: string[]) => permissionManager.hasAnyPermission(permissions), - isAdmin: () => permissionManager.isAdmin(), - isLoggedIn: () => permissionManager.isLoggedIn(), - canAccessPath: (path: string, requiredRoles?: string[], requiredPermissions?: string[]) => - permissionManager.canAccessPath(path, requiredRoles, requiredPermissions) - } -} - -// Vue 3 组合式 API 权限检查 Hook -export const usePermissionCheck = () => { - const userStore = useUserStore() - - return { - // 响应式权限检查 - hasRole: (role: string) => computed(() => permissionManager.hasRole(role)), - hasPermission: (permission: string) => computed(() => permissionManager.hasPermission(permission)), - hasAnyRole: (roles: string[]) => computed(() => permissionManager.hasAnyRole(roles)), - hasAnyPermission: (permissions: string[]) => computed(() => permissionManager.hasAnyPermission(permissions)), - isAdmin: computed(() => permissionManager.isAdmin()), - isLoggedIn: computed(() => permissionManager.isLoggedIn()), - - // 用户信息 - userRoles: computed(() => permissionManager.getUserRoles()), - userPermissions: computed(() => permissionManager.getUserPermissions()), - userInfo: computed(() => userStore.userInfo) - } -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/utils/hertz_request.ts b/hertz_server_django_ui/src/utils/hertz_request.ts deleted file mode 100644 index 3d83449..0000000 --- a/hertz_server_django_ui/src/utils/hertz_request.ts +++ /dev/null @@ -1,201 +0,0 @@ -import axios from 'axios' -import type { AxiosInstance, AxiosRequestConfig, AxiosResponse, InternalAxiosRequestConfig } from 'axios' -import { handleError } from 
'./hertz_error_handler' - -// 请求配置接口 -interface RequestConfig extends AxiosRequestConfig { - showLoading?: boolean - showError?: boolean - metadata?: { - requestId: string - timestamp: string - } -} - -// 响应数据接口 -interface ApiResponse { - code: number - message: string - data: T - success?: boolean -} - -// 请求拦截器配置 -const requestInterceptor = { - onFulfilled: (config: RequestConfig) => { - const timestamp = new Date().toISOString() - const requestId = Math.random().toString(36).substr(2, 9) - - // 简化日志,只在开发环境显示关键信息 - if (import.meta.env.DEV) { - console.log(`🚀 ${config.method?.toUpperCase()} ${config.url}`) - } - - // 添加认证token - const token = localStorage.getItem('token') - if (token) { - config.headers = config.headers || {} - config.headers.Authorization = `Bearer ${token}` - } - - // 如果是FormData,删除Content-Type让浏览器自动设置 - if (config.data instanceof FormData) { - if (config.headers && 'Content-Type' in config.headers) { - delete config.headers['Content-Type'] - } - console.log('📦 检测到FormData,移除Content-Type让浏览器自动设置') - } - - // 显示loading - if (config.showLoading !== false) { - // 这里可以添加loading显示逻辑 - } - - // 将requestId添加到config中,用于响应时匹配 - config.metadata = { requestId, timestamp } - return config as InternalAxiosRequestConfig - }, - onRejected: (error: any) => { - console.error('❌ 请求错误:', error.message) - return Promise.reject(error) - } -} - -// 响应拦截器配置 -const responseInterceptor = { - onFulfilled: (response: AxiosResponse) => { - const requestTimestamp = (response.config as any).metadata?.timestamp - const duration = requestTimestamp ? 
Date.now() - new Date(requestTimestamp).getTime() : 0 - - // 简化日志,只在开发环境显示关键信息 - if (import.meta.env.DEV) { - console.log(`✅ ${response.status} ${response.config.method?.toUpperCase()} ${response.config.url} (${duration}ms)`) - } - - // 统一处理响应数据 - if (response.data && typeof response.data === 'object') { - // 如果后端返回的是标准格式 {code, message, data} - if ('code' in response.data) { - // 标准API响应格式处理 - } - } - - return response - }, - onRejected: (error: any) => { - const requestTimestamp = (error.config as any)?.metadata?.timestamp - const duration = requestTimestamp ? Date.now() - new Date(requestTimestamp).getTime() : 0 - - // 简化错误日志 - console.error(`❌ ${error.response?.status || 'Network'} ${error.config?.method?.toUpperCase()} ${error.config?.url} (${duration}ms)`) - console.error('错误信息:', error.response?.data?.message || error.message) - - // 使用统一错误处理器(支持按请求关闭全局错误提示) - const showError = (error.config as any)?.showError - if (showError !== false) { - handleError(error) - } - - // 特殊处理401错误 - if (error.response?.status === 401) { - console.warn('🔒 未授权,清除token') - localStorage.removeItem('token') - // 可以在这里跳转到登录页 - } - - return Promise.reject(error) - } -} - -class HertzRequest { - private instance: AxiosInstance - - constructor(config: AxiosRequestConfig) { - // 在开发环境中使用空字符串以便Vite代理正常工作 - // 在生产环境中使用完整的API地址 - const isDev = import.meta.env.DEV - const baseURL = isDev ? 
'' : (import.meta.env.VITE_API_BASE_URL || 'http://localhost:8000') - console.log('🔧 创建axios实例 - isDev:', isDev) - console.log('🔧 创建axios实例 - baseURL:', baseURL) - console.log('🔧 环境变量 VITE_API_BASE_URL:', import.meta.env.VITE_API_BASE_URL) - - this.instance = axios.create({ - baseURL, - timeout: 10000, - // 不设置默认Content-Type,让每个请求根据数据类型自动设置 - ...config - }) - - // 添加请求拦截器 - this.instance.interceptors.request.use( - requestInterceptor.onFulfilled, - requestInterceptor.onRejected - ) - - // 添加响应拦截器 - this.instance.interceptors.response.use( - responseInterceptor.onFulfilled, - responseInterceptor.onRejected - ) - } - - // GET请求 - get(url: string, config?: RequestConfig): Promise { - return this.instance.get(url, config).then(res => res.data) - } - - // POST请求 - post(url: string, data?: any, config?: RequestConfig): Promise { - // 如果不是FormData,设置Content-Type为application/json - const finalConfig = { ...config } - if (!(data instanceof FormData)) { - finalConfig.headers = { - 'Content-Type': 'application/json', - ...finalConfig.headers - } - } - return this.instance.post(url, data, finalConfig).then(res => res.data) - } - - // PUT请求 - put(url: string, data?: any, config?: RequestConfig): Promise { - // 如果不是FormData,设置Content-Type为application/json - const finalConfig = { ...config } - if (!(data instanceof FormData)) { - finalConfig.headers = { - 'Content-Type': 'application/json', - ...finalConfig.headers - } - } - return this.instance.put(url, data, finalConfig).then(res => res.data) - } - - // DELETE请求 - delete(url: string, config?: RequestConfig): Promise { - return this.instance.delete(url, config).then(res => res.data) - } - - // PATCH请求 - patch(url: string, data?: any, config?: RequestConfig): Promise { - return this.instance.patch(url, data, config).then(res => res.data) - } - - // 上传文件 - upload(url: string, formData: FormData, config?: RequestConfig): Promise { - // 不要手动设置Content-Type,让浏览器自动设置,这样会包含正确的boundary - return this.instance.post(url, formData, { - 
...config, - headers: { - // 不设置Content-Type,让浏览器自动设置multipart/form-data的header - ...config?.headers - } - }).then(res => res.data) - } -} - -// 创建默认实例 -export const request = new HertzRequest({}) - -// 导出类和配置接口 -export { HertzRequest } -export type { RequestConfig, ApiResponse } diff --git a/hertz_server_django_ui/src/utils/hertz_router_utils.ts b/hertz_server_django_ui/src/utils/hertz_router_utils.ts deleted file mode 100644 index f715f32..0000000 --- a/hertz_server_django_ui/src/utils/hertz_router_utils.ts +++ /dev/null @@ -1,138 +0,0 @@ -/** - * 路由工具函数 - * 用于动态路由相关的辅助功能 - */ - -// 获取views目录下的所有Vue文件 -export const getViewFiles = () => { - const viewsContext = import.meta.glob('@/views/*.vue') - return Object.keys(viewsContext).map(path => path.split('/').pop()) -} - -// 从文件名生成路由名称 -export const generateRouteName = (fileName: string): string => { - return fileName.replace('.vue', '') -} - -// 从文件名生成路由路径 -export const generateRoutePath = (fileName: string): string => { - const routeName = generateRouteName(fileName) - let routePath = `/${routeName.toLowerCase()}` - - // 处理特殊命名(驼峰转短横线) - if (routeName !== routeName.toLowerCase()) { - routePath = `/${routeName.replace(/([A-Z])/g, '-$1').toLowerCase().replace(/^-/, '')}` - } - - return routePath -} - -// 生成路由标题 -export const generateRouteTitle = (routeName: string): string => { - const titleMap: Record = { - Dashboard: '仪表板', - User: '用户管理', - Profile: '个人资料', - Settings: '系统设置', - Test: '样式测试', - WebSocketTest: 'WebSocket测试', - NotFound: '页面未找到', - } - - return titleMap[routeName] || routeName -} - -// 判断路由是否需要认证 -export const shouldRequireAuth = (routeName: string): boolean => { - const publicRoutes = ['Test', 'WebSocketTest'] - return !( - publicRoutes.includes(routeName) || // 公开路由列表 - routeName.startsWith('Demo') // Demo开头的页面不需要认证 - ) -} - -// 获取公开路由列表 -export const getPublicRoutes = (): string[] => { - return ['Test', 'WebSocketTest', 'Demo'] // 可以添加更多公开路由 -} - -// 打印路由调试信息 -export const debugRoutes = () => { - 
const viewFiles = getViewFiles() - const fixedFiles = ['Home.vue', 'Login.vue'] - const dynamicFiles = viewFiles.filter(file => !fixedFiles.includes(file) && file !== 'NotFound.vue') - - console.log('🔍 路由调试信息:') - console.log('📁 所有视图文件:', viewFiles) - console.log('🔒 固定路由文件:', fixedFiles) - console.log('🚀 动态路由文件:', dynamicFiles) - - const publicRoutes = getPublicRoutes() - console.log('🔓 公开路由 (不需要认证):', publicRoutes) - - console.log('\n📋 动态路由配置:') - dynamicFiles.forEach(file => { - const routeName = generateRouteName(file) - const routePath = generateRoutePath(file) - const title = generateRouteTitle(routeName) - const requiresAuth = shouldRequireAuth(routeName) - const isPublic = !requiresAuth - - console.log(` ${file} → ${routePath} (${title}) ${isPublic ? '🔓' : '🔒'}`) - }) - - console.log('\n🎯 Demo页面特殊说明:') - console.log(' - Demo开头的页面不需要认证 (Demo.vue, DemoPage.vue等)') - console.log(' - 可以直接访问 /demo 路径') -} - -// 在开发环境中自动调用调试函数 -if (import.meta.env.DEV) { - debugRoutes() -} - -// 提供全局访问的路由信息查看函数 -export const showRoutesInfo = () => { - console.log('🚀 Hertz Admin 路由配置信息:') - console.log('📋 完整路由列表:') - - // 注意: 这里需要从路由实例中获取真实数据 - // 由于路由工具函数在路由配置之前加载,这里提供的是示例数据 - // 实际的动态路由信息会在项目启动时通过logRouteInfo()函数显示 - - console.log('\n🔒 固定路由 (需要手动配置):') - console.log(' 🔒 / → Home (首页)') - console.log(' 🔓 /login → Login (登录)') - - console.log('\n🚀 动态路由 (自动生成):') - console.log(' 🔒 /dashboard → Dashboard (仪表板)') - console.log(' 🔒 /user → User (用户管理)') - console.log(' 🔒 /profile → Profile (个人资料)') - console.log(' 🔒 /settings → Settings (系统设置)') - console.log(' 🔓 /test → Test (样式测试)') - console.log(' 🔓 /websocket-test → WebSocketTest (WebSocket测试)') - console.log(' 🔓 /demo → Demo (动态路由演示)') - - console.log('\n❓ 404路由:') - console.log(' ❓ /:pathMatch(.*)* → NotFound (页面未找到)') - - console.log('\n📖 访问说明:') - console.log(' 🔓 公开路由: 可以直接访问,不需要登录') - console.log(' 🔒 私有路由: 需要登录后才能访问') - console.log(' 💡 提示: 可以在浏览器中直接访问这些路径') - - console.log('\n🌐 可用链接:') - console.log(' http://localhost:3000/ - 
首页 (需要登录)') - console.log(' http://localhost:3000/login - 登录页面') - console.log(' http://localhost:3000/dashboard - 仪表板 (需要登录)') - console.log(' http://localhost:3000/user - 用户管理 (需要登录)') - console.log(' http://localhost:3000/profile - 个人资料 (需要登录)') - console.log(' http://localhost:3000/settings - 系统设置 (需要登录)') - console.log(' http://localhost:3000/test - 样式测试 (公开)') - console.log(' http://localhost:3000/websocket-test - WebSocket测试 (公开)') - console.log(' http://localhost:3000/demo - 动态路由演示 (公开)') - console.log(' http://localhost:3000/any-other-path - 404页面 (公开)') - - console.log('\n✅ 路由配置加载完成!') - console.log('💡 提示: 启动项目后会在控制台看到真正的动态路由信息') -} diff --git a/hertz_server_django_ui/src/utils/hertz_url.ts b/hertz_server_django_ui/src/utils/hertz_url.ts deleted file mode 100644 index fc077b6..0000000 --- a/hertz_server_django_ui/src/utils/hertz_url.ts +++ /dev/null @@ -1,128 +0,0 @@ -/** - * URL处理工具函数 - */ - -/** - * 获取完整的文件URL - * @param relativePath 相对路径,如 /media/detection/original/xxx.jpg - * @returns 完整的URL - */ -export function getFullFileUrl(relativePath: string): string { - if (!relativePath) { - console.warn('⚠️ 文件路径为空') - return '' - } - - // 如果已经是完整URL,直接返回 - if (relativePath.startsWith('http://') || relativePath.startsWith('https://')) { - return relativePath - } - - // 在开发环境中,使用相对路径(通过Vite代理) - if (import.meta.env.DEV) { - return relativePath - } - - // 在生产环境中,拼接完整的URL - const baseURL = getBackendBaseUrl() - return `${baseURL}${relativePath}` -} - -export function getBackendBaseUrl(): string { - return import.meta.env.VITE_API_BASE_URL || 'http://localhost:3000' -} - -export function getWsBaseUrl(): string { - const httpBase = getBackendBaseUrl() - if (httpBase.startsWith('https://')) { - return 'wss://' + httpBase.slice('https://'.length) - } - if (httpBase.startsWith('http://')) { - return 'ws://' + httpBase.slice('http://'.length) - } - return httpBase -} - -/** - * 获取API基础URL - * @returns API基础URL - */ -export function getApiBaseUrl(): string { - if 
(import.meta.env.DEV) { - return '' // 开发环境使用空字符串,通过Vite代理 - } - return getBackendBaseUrl() -} - -/** - * 获取媒体文件基础URL - * @returns 媒体文件基础URL - */ -export function getMediaBaseUrl(): string { - if (import.meta.env.DEV) { - return '' // 开发环境使用空字符串,通过Vite代理 - } - const baseURL = getBackendBaseUrl() - return baseURL.replace('/api', '') // 移除/api后缀 -} - -/** - * 检查URL是否可访问 - * @param url 要检查的URL - * @returns Promise - */ -export async function checkUrlAccessibility(url: string): Promise { - try { - const response = await fetch(url, { method: 'HEAD' }) - return response.ok - } catch (error) { - console.error('❌ URL访问检查失败:', url, error) - return false - } -} - -/** - * 格式化文件大小 - * @param bytes 字节数 - * @returns 格式化后的文件大小 - */ -export function formatFileSize(bytes: number): string { - if (bytes === 0) return '0 Bytes' - - const k = 1024 - const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB'] - const i = Math.floor(Math.log(bytes) / Math.log(k)) - - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i] -} - -/** - * 获取文件扩展名 - * @param filename 文件名 - * @returns 文件扩展名 - */ -export function getFileExtension(filename: string): string { - return filename.split('.').pop()?.toLowerCase() || '' -} - -/** - * 检查是否为图片文件 - * @param filename 文件名或URL - * @returns 是否为图片文件 - */ -export function isImageFile(filename: string): boolean { - const imageExtensions = ['jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg'] - const extension = getFileExtension(filename) - return imageExtensions.includes(extension) -} - -/** - * 检查是否为视频文件 - * @param filename 文件名或URL - * @returns 是否为视频文件 - */ -export function isVideoFile(filename: string): boolean { - const videoExtensions = ['mp4', 'avi', 'mov', 'wmv', 'flv', 'webm', 'mkv'] - const extension = getFileExtension(filename) - return videoExtensions.includes(extension) -} diff --git a/hertz_server_django_ui/src/utils/hertz_utils.ts b/hertz_server_django_ui/src/utils/hertz_utils.ts deleted file mode 100644 index e051aaf..0000000 --- 
a/hertz_server_django_ui/src/utils/hertz_utils.ts +++ /dev/null @@ -1,251 +0,0 @@ -import { useAppStore } from '@/stores/hertz_app' - -// 日期格式化 -export const formatDate = (date: Date | string | number, format = 'YYYY-MM-DD HH:mm:ss') => { - const d = new Date(date) - - if (isNaN(d.getTime())) { - return '' - } - - const year = d.getFullYear() - const month = String(d.getMonth() + 1).padStart(2, '0') - const day = String(d.getDate()).padStart(2, '0') - const hours = String(d.getHours()).padStart(2, '0') - const minutes = String(d.getMinutes()).padStart(2, '0') - const seconds = String(d.getSeconds()).padStart(2, '0') - - return format - .replace('YYYY', year.toString()) - .replace('MM', month) - .replace('DD', day) - .replace('HH', hours) - .replace('mm', minutes) - .replace('ss', seconds) -} - -// 防抖函数 -export const debounce = any>( - func: T, - delay: number -): ((...args: Parameters) => void) => { - let timeoutId: NodeJS.Timeout - - return (...args: Parameters) => { - clearTimeout(timeoutId) - timeoutId = setTimeout(() => func(...args), delay) - } -} - -// 节流函数 -export const throttle = any>( - func: T, - delay: number -): ((...args: Parameters) => void) => { - let lastCall = 0 - - return (...args: Parameters) => { - const now = Date.now() - - if (now - lastCall >= delay) { - lastCall = now - func(...args) - } - } -} - -// 深拷贝 -export const deepClone = (obj: T): T => { - if (obj === null || typeof obj !== 'object') { - return obj - } - - if (obj instanceof Date) { - return new Date(obj.getTime()) as T - } - - if (obj instanceof Array) { - return obj.map(item => deepClone(item)) as T - } - - if (typeof obj === 'object') { - const cloned = {} as T - for (const key in obj) { - if (obj.hasOwnProperty(key)) { - cloned[key] = deepClone(obj[key]) - } - } - return cloned - } - - return obj -} - -// 数组去重 -export const unique = (arr: T[]): T[] => { - return Array.from(new Set(arr)) -} - -// 获取URL参数 -export const getUrlParam = (name: string, url?: string): string | null => { 
- const searchUrl = url || window.location.search - const params = new URLSearchParams(searchUrl) - return params.get(name) -} - -// 设置URL参数 -export const setUrlParam = (name: string, value: string, url?: string): string => { - const searchUrl = url || window.location.search - const params = new URLSearchParams(searchUrl) - - if (value === null || value === undefined || value === '') { - params.delete(name) - } else { - params.set(name, value) - } - - return params.toString() -} - -// 复制到剪贴板 -export const copyToClipboard = async (text: string): Promise => { - try { - if (navigator.clipboard && window.isSecureContext) { - await navigator.clipboard.writeText(text) - } else { - // 降级处理 - const textArea = document.createElement('textarea') - textArea.value = text - textArea.style.position = 'fixed' - textArea.style.left = '-999999px' - textArea.style.top = '-999999px' - document.body.appendChild(textArea) - textArea.focus() - textArea.select() - - const successful = document.execCommand('copy') - textArea.remove() - - if (!successful) { - throw new Error('复制失败') - } - } - return true - } catch (error) { - console.error('复制失败:', error) - return false - } -} - -// 下载文件 -export const downloadFile = (url: string, filename?: string) => { - const link = document.createElement('a') - link.href = url - link.download = filename || '' - link.style.display = 'none' - document.body.appendChild(link) - link.click() - document.body.removeChild(link) -} - -// 格式化文件大小 -export const formatFileSize = (bytes: number): string => { - if (bytes === 0) return '0 B' - - const k = 1024 - const sizes = ['B', 'KB', 'MB', 'GB', 'TB'] - const i = Math.floor(Math.log(bytes) / Math.log(k)) - - return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i] -} - -// 验证邮箱格式 -export const isValidEmail = (email: string): boolean => { - const emailRegex = /^[^\s@]+@[^\s@]+\.[^\s@]+$/ - return emailRegex.test(email) -} - -// 验证手机号格式(中国大陆) -export const isValidPhone = (phone: string): boolean => { 
- const phoneRegex = /^1[3-9]\d{9}$/ - return phoneRegex.test(phone) -} - -// 验证身份证号 -export const isValidIdCard = (idCard: string): boolean => { - const idCardRegex = /(^\d{15}$)|(^\d{18}$)|(^\d{17}(\d|X|x)$)/ - return idCardRegex.test(idCard) -} - -// 生成随机字符串 -export const generateRandomString = (length: number = 8): string => { - const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' - let result = '' - - for (let i = 0; i < length; i++) { - result += chars.charAt(Math.floor(Math.random() * chars.length)) - } - - return result -} - -// 等待函数 -export const sleep = (ms: number): Promise => { - return new Promise(resolve => setTimeout(resolve, ms)) -} - -// 获取浏览器信息 -export const getBrowserInfo = () => { - const userAgent = navigator.userAgent - const isChrome = /Chrome/.test(userAgent) && /Google Inc/.test(navigator.vendor) - const isFirefox = /Firefox/.test(userAgent) - const isSafari = /Safari/.test(userAgent) && /Apple Computer/.test(navigator.vendor) - const isEdge = /Edg/.test(userAgent) - const isIE = /MSIE|Trident/.test(userAgent) - - return { - isChrome, - isFirefox, - isSafari, - isEdge, - isIE, - userAgent, - } -} - -// 本地存储封装 -export const storage = { - get: (key: string, defaultValue?: T): T | null => { - try { - const item = localStorage.getItem(key) - return item ? JSON.parse(item) : (defaultValue ?? null) - } catch (error) { - console.error(`获取本地存储失败 (${key}):`, error) - return defaultValue ?? 
null - } - }, - - set: (key: string, value: T): void => { - try { - localStorage.setItem(key, JSON.stringify(value)) - } catch (error) { - console.error(`设置本地存储失败 (${key}):`, error) - } - }, - - remove: (key: string): void => { - try { - localStorage.removeItem(key) - } catch (error) { - console.error(`删除本地存储失败 (${key}):`, error) - } - }, - - clear: (): void => { - try { - localStorage.clear() - } catch (error) { - console.error('清空本地存储失败:', error) - } - }, -} diff --git a/hertz_server_django_ui/src/utils/menu_mapping.ts b/hertz_server_django_ui/src/utils/menu_mapping.ts deleted file mode 100644 index 60210ca..0000000 --- a/hertz_server_django_ui/src/utils/menu_mapping.ts +++ /dev/null @@ -1,112 +0,0 @@ -import { menuApi, type Menu } from '@/api/menu' - -// 菜单key和菜单ID的映射关系 -let menuKeyToIdMap: Map = new Map() -let menuIdToKeyMap: Map = new Map() -let isInitialized = false - -// 菜单key和菜单code的映射关系(用于建立映射) -const MENU_KEY_TO_CODE_MAP: { [key: string]: string } = { - 'dashboard': 'dashboard', - 'user-management': 'user_management', - 'department-management': 'department_management', - 'menu-management': 'menu_management', - 'teacher': 'role_management' -} - -/** - * 初始化菜单映射 - */ -export const initializeMenuMapping = async (): Promise => { - try { - // 获取菜单树数据 - const response = await menuApi.getMenuTree() - - if (response.code === 200 && response.data) { - // 清空现有映射 - menuKeyToIdMap.clear() - - // 递归处理菜单树 - const processMenuTree = (menus: Menu[]) => { - menus.forEach(menu => { - if (menu.key && menu.id) { - menuKeyToIdMap.set(menu.key, menu.id) - } - - // 递归处理子菜单 - if (menu.children && menu.children.length > 0) { - processMenuTree(menu.children) - } - }) - } - - processMenuTree(response.data) - } - } catch (error) { - console.error('初始化菜单映射时发生错误:', error) - } -} - -/** - * 递归构建菜单映射关系 - */ -const buildMenuMapping = (menus: Menu[]): void => { - menus.forEach(menu => { - // 根据menu_code找到对应的key - const menuKey = Object.keys(MENU_KEY_TO_CODE_MAP).find( - key => 
MENU_KEY_TO_CODE_MAP[key] === menu.menu_code - ) - - if (menuKey) { - menuKeyToIdMap.set(menuKey, menu.menu_id) - menuIdToKeyMap.set(menu.menu_id, menuKey) - } - - // 递归处理子菜单 - if (menu.children && menu.children.length > 0) { - buildMenuMapping(menu.children) - } - }) -} - -/** - * 根据菜单key获取菜单ID - */ -export const getMenuIdByKey = (menuKey: string): number | undefined => { - return menuKeyToIdMap.get(menuKey) -} - -/** - * 根据菜单ID获取菜单key - */ -export const getMenuKeyById = (menuId: number): string | undefined => { - return menuIdToKeyMap.get(menuId) -} - -/** - * 检查用户是否有指定菜单的权限 - */ -export const hasMenuPermissionById = (menuKey: string, userMenuPermissions: number[]): boolean => { - const menuId = getMenuIdByKey(menuKey) - - if (!menuId) { - // 降级策略:如果没有找到菜单映射,则允许显示(向后兼容) - return true - } - - return userMenuPermissions.includes(menuId) -} - -/** - * 获取用户有权限的菜单keys - */ -export const getPermittedMenuKeys = (userMenuPermissions: number[]): string[] => { - const permittedKeys: string[] = [] - userMenuPermissions.forEach(menuId => { - const menuKey = getMenuKeyById(menuId) - if (menuKey) { - permittedKeys.push(menuKey) - } - }) - return permittedKeys -} \ No newline at end of file diff --git a/hertz_server_django_ui/src/utils/yolo_frontend.ts b/hertz_server_django_ui/src/utils/yolo_frontend.ts deleted file mode 100644 index 2eb19ed..0000000 --- a/hertz_server_django_ui/src/utils/yolo_frontend.ts +++ /dev/null @@ -1,730 +0,0 @@ -// 前端ONNX YOLO检测工具类 -import * as ort from 'onnxruntime-web' - -// ONNX检测结果接口 -export interface YOLODetectionResult { - detections: Array<{ - class_name: string - confidence: number - bbox: { - x: number - y: number - width: number - height: number - } - }> - object_count: number - detected_categories: string[] - confidence_scores: number[] - avg_confidence: number - annotated_image: string // base64图像 - processing_time: number -} - -// 不预置任何类别名称;等待后端或标签文件提供 - -class YOLODetector { - private session: ort.InferenceSession | null = null - private 
modelPath: string = '' - private classNames: string[] = [] - private inputShape: [number, number] = [640, 640] // 默认输入尺寸(可在 WASM 下动态调小) - private currentEP: 'webgpu' | 'webgl' | 'wasm' = 'wasm' - - /** - * 加载ONNX模型 - * @param modelPath 模型路径(相对于public目录) - * @param classNames 类别名称列表(可选,如果不提供则使用默认COCO类别) - */ - async loadModel(modelPath: string, classNames?: string[], forceEP?: 'webgpu' | 'webgl' | 'wasm'): Promise { - try { - console.log('🔄 开始加载ONNX模型:', modelPath) - - // 设置类别名称 - if (classNames && classNames.length > 0) { - this.classNames = classNames - console.log('📦 使用自定义类别:', classNames.length, '个类别') - } else { - // 如果未提供类别,稍后根据输出维度自动推断数量并用 class_0.. 命名 - this.classNames = [] - console.log('📦 未提供类别,将根据模型输出自动推断类别数量') - } - - // 动态选择可用的 wasm 资源路径,避免 404/HTML 导致的“magic word”错误 - const ensureWasmPath = async () => { - const candidates = [ - 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.2/dist/', - 'https://unpkg.com/onnxruntime-web@1.23.2/dist/', - '/onnxruntime-web/', // 如果你把 dist 拷贝到 public/onnxruntime-web/ - '/ort/' // 或者 public/ort/ - ] - for (const base of candidates) { - try { - const testUrl = base.replace(/\/$/, '') + '/ort-wasm.wasm' - const res = await fetch(testUrl, { method: 'HEAD', cache: 'no-store', mode: 'no-cors' as any }) - // no-cors 模式下 status 为 0,也视为可用(跨域但可下载) - if (res.ok || res.status === 0) { - // @ts-ignore - ort.env.wasm.wasmPaths = base - return true - } - } catch {} - } - return false - } - - await ensureWasmPath() - - // 配置 WASM 线程:若不支持跨域隔离/SharedArrayBuffer,则退回单线程,避免“worker not ready” - const canMultiThread = (self as any).crossOriginIsolated && typeof (self as any).SharedArrayBuffer !== 'undefined' - try { - // @ts-ignore - ort.env.wasm.numThreads = canMultiThread ? 
Math.max(2, Math.min(4, (navigator as any)?.hardwareConcurrency || 2)) : 1 - // @ts-ignore - ort.env.wasm.proxy = true - } catch {} - - const createWithEP = async (ep: 'webgpu' | 'webgl' | 'wasm') => { - if (ep === 'webgpu') { - const prefer: ort.InferenceSession.SessionOptions = { - executionProviders: ['webgpu'], - graphOptimizationLevel: 'all', - } - this.session = await ort.InferenceSession.create(modelPath, prefer) - this.currentEP = 'webgpu' - return - } - if (ep === 'webgl') { - const prefer: ort.InferenceSession.SessionOptions = { - executionProviders: ['webgl'], - graphOptimizationLevel: 'all', - } - this.session = await ort.InferenceSession.create(modelPath, prefer) - this.currentEP = 'webgl' - return - } - // wasm - const prefer: ort.InferenceSession.SessionOptions = { - executionProviders: ['wasm'], - graphOptimizationLevel: 'all', - } - this.session = await ort.InferenceSession.create(modelPath, prefer) - this.currentEP = 'wasm' - } - - // 配置 ONNX Runtime:优先 GPU(WebGPU/WebGL),再回退 WASM - // 支持通过 localStorage 开关强制使用 WASM:localStorage.setItem('ort_force_wasm','1') - // 也可通过第三个参数 forceEP 指定(用于错误时的程序化降级) - const forceWasm = forceEP === 'wasm' || (localStorage.getItem('ort_force_wasm') === '1') - // 1) WebGPU(实验性,浏览器需支持 navigator.gpu) - let created = false - if (!forceWasm && (navigator as any)?.gpu && (!forceEP || forceEP === 'webgpu')) { - try { - // 动态引入 webgpu 版本(若不支持不会打包) - await import('onnxruntime-web/webgpu') - await createWithEP('webgpu') - created = true - console.log('✅ 使用 WebGPU 推理') - } catch (e) { - console.warn('⚠️ WebGPU 初始化失败,回退到 WebGL:', e) - } - } - - // 2) WebGL(GPU 加速,兼容更好) - if (!forceWasm && !created && (!forceEP || forceEP === 'webgl')) { - try { - await createWithEP('webgl') - created = true - console.log('✅ 使用 WebGL 推理') - } catch (e2) { - console.warn('⚠️ WebGL 初始化失败,回退到 WASM:', e2) - } - } - - // 3) WASM(CPU) - if (!created) { - try { - // 设置 WASM 线程/特性(路径已在 ensureWasmPath 中选择) - // @ts-ignore - ort.env.wasm.numThreads = 
Math.max(1, Math.min(4, (navigator as any)?.hardwareConcurrency || 2)) - // @ts-ignore - ort.env.wasm.proxy = true - } catch {} - - await createWithEP('wasm') - console.log('✅ 使用 WASM 推理') - } - this.modelPath = modelPath - - // 根据后端动态调整输入尺寸:WASM 默认调小以提升流畅度,可用 localStorage 覆盖 - try { - const override = parseInt(localStorage.getItem('ort_input_size') || '', 10) - if (Number.isFinite(override) && override >= 256 && override <= 1024) { - this.inputShape = [override, override] as any - } else if (this.currentEP === 'wasm') { - this.inputShape = [512, 512] as any - } else { - this.inputShape = [640, 640] as any - } - console.log('🧩 推理输入尺寸:', this.inputShape[0]) - } catch {} - - // 获取模型输入输出信息(兼容性更强的写法) - const inputNames = this.session.inputNames - const outputNames = this.session.outputNames - console.log('✅ 模型加载成功') - console.log('📥 输入:', inputNames) - console.log('📤 输出:', outputNames) - - // 尝试从 outputMetadata 推断类别数(某些环境不提供 dims,需要兜底) - try { - if (outputNames && outputNames.length > 0) { - const outputMetadata: any = (this.session as any).outputMetadata - const outputName = outputNames[0] - const meta = outputMetadata?.[outputName] - const outputShape: number[] | undefined = meta?.dims - if (Array.isArray(outputShape) && outputShape.length >= 3) { - const numClasses = (outputShape[2] as number) - 5 // YOLO: [N, B, 5+C] - if (Number.isFinite(numClasses) && numClasses > 0 && numClasses !== this.classNames.length) { - console.warn(`⚠️ 模型输出类别数 (${numClasses}) 与提供的类别数 (${this.classNames.length}) 不匹配/或未提供`) - if (this.classNames.length === 0) { - this.classNames = Array.from({ length: numClasses }, (_, i) => `class_${i}`) - console.log('📦 根据模型输出调整类别数量为:', numClasses) - } - } - } else { - console.warn('⚠️ 无法从 outputMetadata 推断输出维度,将在首次推理时根据输出tensor推断。') - } - } - } catch (metaErr) { - console.warn('⚠️ 读取 outputMetadata 失败,将在首次推理时推断类别数。', metaErr) - } - } catch (error) { - console.error('❌ 加载模型失败:', error) - throw new Error(`加载ONNX模型失败: ${error instanceof Error ? 
error.message : '未知错误'}`) - } - } - - /** - * 检查模型是否已加载 - */ - isLoaded(): boolean { - return this.session !== null - } - - /** - * 获取当前加载的模型路径 - */ - getModelPath(): string { - return this.modelPath - } - - /** - * 获取类别名称列表 - */ - getClassNames(): string[] { - return this.classNames - } - - /** - * 预处理图像(Ultralytics letterbox:保比例缩放+灰边填充) - * 返回输入张量与还原坐标所需的比例与padding - */ - private preprocessImage(image: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement): { - input: Float32Array - ratio: number - padX: number - padY: number - dstW: number - dstH: number - srcW: number - srcH: number - } { - const dstW = this.inputShape[0] - const dstH = this.inputShape[1] - const canvas = document.createElement('canvas') - const ctx = canvas.getContext('2d') - if (!ctx) throw new Error('无法创建canvas上下文') - - const srcW = image instanceof HTMLVideoElement ? image.videoWidth : (image as HTMLImageElement | HTMLCanvasElement).width - const srcH = image instanceof HTMLVideoElement ? image.videoHeight : (image as HTMLImageElement | HTMLCanvasElement).height - - // 计算 letterbox - const r = Math.min(dstW / srcW, dstH / srcH) - const newW = Math.round(srcW * r) - const newH = Math.round(srcH * r) - const padX = Math.floor((dstW - newW) / 2) - const padY = Math.floor((dstH - newH) / 2) - - canvas.width = dstW - canvas.height = dstH - // 背景填充灰色(114)与 Ultralytics 一致 - ctx.fillStyle = 'rgb(114,114,114)' - ctx.fillRect(0, 0, dstW, dstH) - // 绘制等比缩放后的图像到中间 - ctx.drawImage(image as any, 0, 0, srcW, srcH, padX, padY, newW, newH) - - const imageData = ctx.getImageData(0, 0, dstW, dstH) - const data = imageData.data - - const input = new Float32Array(3 * dstW * dstH) - for (let i = 0; i < data.length; i += 4) { - const r8 = data[i] / 255.0 - const g8 = data[i + 1] / 255.0 - const b8 = data[i + 2] / 255.0 - const idx = i / 4 - input[idx] = r8 - input[idx + dstW * dstH] = g8 - input[idx + dstW * dstH * 2] = b8 - } - - return { input, ratio: r, padX, padY, dstW, dstH, srcW, srcH } - } - - /** - * 
非极大值抑制(NMS) - */ - private nms(boxes: Array<{x: number, y: number, w: number, h: number, conf: number, class: number}>, iouThreshold: number): number[] { - if (boxes.length === 0) return [] - - // 按置信度排序 - boxes.sort((a, b) => b.conf - a.conf) - - const selected: number[] = [] - const suppressed = new Set() - - for (let i = 0; i iouThreshold) { - suppressed.add(j) - } - } - } - - return selected - } - - /** - * 计算IoU(交并比) - */ - private calculateIoU(box1: {x: number, y: number, w: number, h: number}, box2: {x: number, y: number, w: number, h: number}): number { - const x1 = Math.max(box1.x, box2.x) - const y1 = Math.max(box1.y, box2.y) - const x2 = Math.min(box1.x + box1.w, box2.x + box2.w) - const y2 = Math.min(box1.y + box1.h, box2.y + box2.h) - - if (x2 < x1 || y2 < y1) return 0 - - const intersection = (x2 - x1) * (y2 - y1) - const area1 = box1.w * box1.h - const area2 = box2.w * box2.h - const union = area1 + area2 - intersection - - return intersection / union - } - - /** - * 后处理检测结果 - */ - private postprocess( - output: ort.Tensor, - meta: { ratio: number; padX: number; padY: number; srcW: number; srcH: number }, - confThreshold: number, - nmsThreshold: number, - opts?: { maxDetections?: number; minBoxArea?: number; classWise?: boolean } - ): Array<{class_name: string, confidence: number, bbox: {x: number, y: number, width: number, height: number}}> { - const outputData = output.data as Float32Array - const outputShape = output.dims || [] - - // YOLO输出常见两种: - // A) [1, num_boxes, 5+num_classes] - // B) [1, 5+num_classes, num_boxes] - // 另外也可能已经扁平化为 [num_boxes, 5+num_classes] - let numBoxes = 0 - let numFeatures = 0 - if (outputShape.length === 3) { - // 取更大的作为 boxes 维度(通常是 8400),较小的是 5+C(通常是 85) - const a = outputShape[1] as number - const b = outputShape[2] as number - if (a >= b) { - numBoxes = a - numFeatures = b - } else { - numBoxes = b - numFeatures = a - } - } else if (outputShape.length === 2) { - numBoxes = outputShape[0] as number - numFeatures = 
outputShape[1] as number - } else { - // 无维度信息时根据长度推断(保底) - numFeatures = 85 - numBoxes = Math.floor(outputData.length / numFeatures) - } - const numClasses = Math.max(0, numFeatures - 5) - - const detections: Array<{x: number, y: number, w: number, h: number, conf: number, class: number}> = [] - - // 还原到原图坐标:先减去 padding,再除以 ratio - const { ratio, padX, padY, srcW: originalWidth, srcH: originalHeight } = meta - - // 获取 (row i, col j) 的值,兼容布局 A/B - const getVal = (i: number, j: number): number => { - if (outputShape.length === 3) { - const a = outputShape[1] as number - const b = outputShape[2] as number - if (a >= b) { - // [1, boxes, features] - return outputData[i * b + j] - } - // [1, features, boxes] - return outputData[j * a + i] - } - // [boxes, features] - return outputData[i * numFeatures + j] - } - - // sigmoid - const sigmoid = (v: number) => 1 / (1 + Math.exp(-v)) - - // 情况一:部分导出的ONNX已经做过NMS,输出形如 [num, 6]:x1,y1,x2,y2,score,classId(或其它顺序)。 - const tryPostNmsLayouts = () => { - const candidates: Array<(row: (j:number)=>number) => {x:number,y:number,w:number,h:number,conf:number,cls:number} | null> = [ - // [x1,y1,x2,y2,score,cls] - (get) => { - const x1 = get(0), y1 = get(1), x2 = get(2), y2 = get(3) - const score = get(4), cls = get(5) - if (!isFinite(x1+y1+x2+y2+score+cls)) return null - if (score < 0 || score > 1) return null - const w = Math.abs(x2 - x1), h = Math.abs(y2 - y1) - return { x: Math.min(x1,x2), y: Math.min(y1,y2), w, h, conf: score, cls: Math.max(0, Math.floor(cls)) } - }, - // [cls,score,x1,y1,x2,y2] - (get) => { - const cls = get(0), score = get(1), x1 = get(2), y1 = get(3), x2 = get(4), y2 = get(5) - if (!isFinite(x1+y1+x2+y2+score+cls)) return null - if (score < 0 || score > 1) return null - const w = Math.abs(x2 - x1), h = Math.abs(y2 - y1) - return { x: Math.min(x1,x2), y: Math.min(y1,y2), w, h, conf: score, cls: Math.max(0, Math.floor(cls)) } - }, - // [x,y,w,h,score,cls](xywh) - (get) => { - const x = get(0), y = get(1), w = 
get(2), h = get(3), score = get(4), cls = get(5) - if (!isFinite(x+y+w+h+score+cls)) return null - if (score < 0 || score > 1) return null - return { x: x - w/2, y: y - h/2, w, h, conf: score, cls: Math.max(0, Math.floor(cls)) } - } - ] - const out: typeof detections = [] - for (let i = 0; i < numBoxes; i++) { - const getter = (j:number) => getVal(i, j) - let picked = null - for (const decode of candidates) { - picked = decode(getter) - if (picked && picked.conf >= confThreshold) break - } - if (!picked || picked.conf < confThreshold) continue - // 还原坐标 - let { x, y, w, h, conf, cls } = picked - x = (x - padX) / ratio - y = (y - padY) / ratio - w = w / ratio - h = h / ratio - const area = Math.max(0, w) * Math.max(0, h) - const minArea = opts?.minBoxArea ?? (meta.srcW * meta.srcH * 0.0001) - if (area <= 0 || area < minArea) continue - out.push({ x, y, w, h, conf, class: cls }) - } - return out - } - - // 情况二:原始预测 [*, *, 5+num_classes],需要 obj × class 计算。 - // 支持两种坐标格式:xywh(中心点) 与 xyxy(左上/右下)。优先取能得到更多有效框的解码。 - const decode = (mode: 'xywh' | 'xyxy') => { - const out: typeof detections = [] - for (let i = 0; i < numBoxes; i++) { - const v0 = getVal(i, 0) - const v1 = getVal(i, 1) - const v2 = getVal(i, 2) - const v3 = getVal(i, 3) - const objConf = sigmoid(getVal(i, 4)) - - // 最大类别 - let maxClassConf = 0 - let maxClassIdx = 0 - for (let j = 0; j < numClasses; j++) { - const classConf = sigmoid(getVal(i, 5 + j)) - if (classConf > maxClassConf) { - maxClassConf = classConf - maxClassIdx = j - } - } - const confidence = objConf * maxClassConf - if (confidence < confThreshold) continue - - let x = 0, y = 0, w = 0, h = 0 - if (mode === 'xywh') { - const xc = (v0 - padX) / ratio - const yc = (v1 - padY) / ratio - const wv = v2 / ratio - const hv = v3 / ratio - x = xc - wv / 2 - y = yc - hv / 2 - w = wv - h = hv - } else { - // xyxy - const x1 = (v0 - padX) / ratio - const y1 = (v1 - padY) / ratio - const x2 = (v2 - padX) / ratio - const y2 = (v3 - padY) / ratio - x = 
Math.min(x1, x2) - y = Math.min(y1, y2) - w = Math.abs(x2 - x1) - h = Math.abs(y2 - y1) - } - const area = Math.max(0, w) * Math.max(0, h) - const minArea = opts?.minBoxArea ?? (originalWidth * originalHeight * 0.00005) // 放宽:0.005% - if (area <= 0 || area < minArea) continue - if (w > 4 * originalWidth || h > 4 * originalHeight) continue // 明显异常 - - out.push({ x, y, w, h, conf: confidence, class: maxClassIdx }) - } - return out - } - - let pick: typeof detections = [] - // 若特征维很小(<=6),优先按“已NMS格式”解析 - if (numFeatures <= 6) { - pick = tryPostNmsLayouts() - } - // 否则按原始格式解码 - if (pick.length === 0) { - const d1 = decode('xywh') - const d2 = decode('xyxy') - pick = d2.length > d1.length ? d2 : d1 - } - detections.push(...pick) - // 执行NMS(支持按类别) - const classWise = opts?.classWise ?? true - let kept: Array<{x: number, y: number, w: number, h: number, conf: number, class: number}> = [] - if (classWise) { - const byClass: Record = {} - for (const d of detections) { - (byClass[d.class] ||= []).push(d) - } - for (const k in byClass) { - const group = byClass[k] - const idxs = this.nms(group, nmsThreshold) - kept.push(...idxs.map(i => group[i])) - } - } else { - const idxs = this.nms(detections, nmsThreshold) - kept = idxs.map(i => detections[i]) - } - - // 置信度排序并限制最大数量 - kept.sort((a, b) => b.conf - a.conf) - const limited = kept.slice(0, opts?.maxDetections ?? 
100) - - // 构建最终结果 - return limited.map(det => { - const className = this.classNames[det.class] || `class_${det.class}` - return { - class_name: className, - confidence: det.conf, - bbox: { - x: Math.max(0, det.x), - y: Math.max(0, det.y), - width: Math.min(det.w, originalWidth - det.x), - height: Math.min(det.h, originalHeight - det.y) - } - } - }) - } - - /** - * 在图像上绘制检测框 - */ - private drawDetections(canvas: HTMLCanvasElement, detections: Array<{class_name: string, confidence: number, bbox: {x: number, y: number, width: number, height: number}}>): void { - const ctx = canvas.getContext('2d') - if (!ctx) return - - // 为每个类别分配颜色 - const colors: {[key: string]: string} = {} - const colorPalette = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8', '#F7DC6F', '#BB8FCE', '#85C1E2'] - - detections.forEach((det, idx) => { - if (!colors[det.class_name]) { - colors[det.class_name] = colorPalette[idx % colorPalette.length] - } - }) - - detections.forEach(det => { - const { x, y, width, height } = det.bbox - const color = colors[det.class_name] - - // 绘制边界框 - ctx.strokeStyle = color - ctx.lineWidth = 2 - ctx.strokeRect(x, y, width, height) - - // 绘制标签背景 - const label = `${det.class_name} ${(det.confidence * 100).toFixed(1)}%` - ctx.font = '14px Arial' - const textMetrics = ctx.measureText(label) - const textWidth = textMetrics.width - const textHeight = 20 - - ctx.fillStyle = color - ctx.fillRect(x, y - textHeight, textWidth + 10, textHeight) - - // 绘制标签文字 - ctx.fillStyle = '#FFFFFF' - ctx.fillText(label, x + 5, y - 5) - }) - } - - /** - * 执行检测 - * @param image 图像元素(Image, Video, 或 Canvas) - * @param confidenceThreshold 置信度阈值 - * @param nmsThreshold NMS阈值 - */ - async detect( - image: HTMLImageElement | HTMLVideoElement | HTMLCanvasElement, - confidenceThreshold: number = 0.25, - nmsThreshold: number = 0.7 - ): Promise { - if (!this.session) { - throw new Error('模型未加载,请先调用 loadModel()') - } - - const startTime = performance.now() - - try { - // 获取原始图像尺寸 - const 
originalWidth = image instanceof HTMLVideoElement ? image.videoWidth : image.width - const originalHeight = image instanceof HTMLVideoElement ? image.videoHeight : image.height - - // 预处理图像(letterbox) - const prep = this.preprocessImage(image) - - // 创建输入tensor [1, 3, H, W] - const inputTensor = new ort.Tensor('float32', prep.input, [1, 3, this.inputShape[1], this.inputShape[0]]) - - // 执行推理 - const inputName = this.session.inputNames[0] - const feeds = { [inputName]: inputTensor } - const results = await this.session.run(feeds) - - // 获取输出 - const outputName = this.session.outputNames[0] - const output = results[outputName] - - // 后处理 - const detections = this.postprocess( - output, - { ratio: prep.ratio, padX: prep.padX, padY: prep.padY, srcW: originalWidth, srcH: originalHeight }, - confidenceThreshold, - nmsThreshold, - { maxDetections: 100, minBoxArea: originalWidth * originalHeight * 0.0001, classWise: true } - ) - - // 计算统计信息 - const objectCount = detections.length - const detectedCategories = [...new Set(detections.map(d => d.class_name))] - const confidenceScores = detections.map(d => d.confidence) - const avgConfidence = confidenceScores.length > 0 - ? 
confidenceScores.reduce((a, b) => a + b, 0) / confidenceScores.length - : 0 - - // 绘制检测结果 - const canvas = document.createElement('canvas') - canvas.width = originalWidth - canvas.height = originalHeight - const ctx = canvas.getContext('2d') - if (ctx) { - ctx.drawImage(image, 0, 0, originalWidth, originalHeight) - this.drawDetections(canvas, detections) - } - - // 转换为base64(降低质量,减少内存与传输开销) - const annotatedImage = canvas.toDataURL('image/jpeg', 0.4) - - const processingTime = (performance.now() - startTime) / 1000 - - return { - detections: detections.map(d => ({ - class_name: d.class_name, - confidence: d.confidence, - bbox: d.bbox - })), - object_count: objectCount, - detected_categories: detectedCategories, - confidence_scores: confidenceScores, - avg_confidence: avgConfidence, - annotated_image: annotatedImage, - processing_time: processingTime - } - } catch (error) { - console.error('❌ 检测失败:', error) - // 若 GPU 后端不支持某些算子,自动回退到 WASM 并重试一次 - const msg = String((error as any)?.message || error) - const needFallback = /GatherND|Unsupported data type|JSF Kernel|ExecuteKernel|WebGPU|WebGL|worker not ready/i.test(msg) - if (needFallback && this.currentEP !== 'wasm') { - try { - console.warn('⚠️ 检测算子不被 GPU 支持,自动回退到 WASM 并重试一次。') - // 强制全局与本次调用走 WASM - localStorage.setItem('ort_force_wasm','1') - await this.loadModel(this.modelPath, this.classNames, 'wasm') - // 强制使用 wasm - // @ts-ignore - ort.env.wasm.proxy = true - this.currentEP = 'wasm' - return await this.detect(image, confidenceThreshold, nmsThreshold) - } catch (e2) { - console.error('❌ 回退到 WASM 后仍失败:', e2) - } - } - // 如果已是 wasm,但报 worker not ready,再降级为单线程重建 session - if (/worker not ready/i.test(msg) && this.currentEP === 'wasm') { - try { - // @ts-ignore - ort.env.wasm.numThreads = 1 - await this.loadModel(this.modelPath, this.classNames) - return await this.detect(image, confidenceThreshold, nmsThreshold) - } catch (e3) { - console.error('❌ 降级单线程后仍失败:', e3) - } - } - throw new Error(`检测失败: ${error 
instanceof Error ? error.message : '未知错误'}`) - } - } - - /** - * 释放模型资源 - */ - dispose(): void { - if (this.session) { - // ONNX Runtime会自动管理资源,但我们可以清理引用 - this.session = null - this.modelPath = '' - console.log('🗑️ 模型资源已释放') - } - } -} - -// 导出单例 -export const yoloDetector = new YOLODetector() diff --git a/hertz_server_django_ui/src/views/Home.vue b/hertz_server_django_ui/src/views/Home.vue deleted file mode 100644 index ec7a029..0000000 --- a/hertz_server_django_ui/src/views/Home.vue +++ /dev/null @@ -1,505 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/Login.vue b/hertz_server_django_ui/src/views/Login.vue deleted file mode 100644 index b9bac72..0000000 --- a/hertz_server_django_ui/src/views/Login.vue +++ /dev/null @@ -1,464 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/ModuleSetup.vue b/hertz_server_django_ui/src/views/ModuleSetup.vue deleted file mode 100644 index a540684..0000000 --- a/hertz_server_django_ui/src/views/ModuleSetup.vue +++ /dev/null @@ -1,149 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/NotFound.vue b/hertz_server_django_ui/src/views/NotFound.vue deleted file mode 100644 index 3cc9b5f..0000000 --- a/hertz_server_django_ui/src/views/NotFound.vue +++ /dev/null @@ -1,65 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/admin_page/AlertLevelManagement.vue b/hertz_server_django_ui/src/views/admin_page/AlertLevelManagement.vue deleted file mode 100644 index cc2bfc5..0000000 --- a/hertz_server_django_ui/src/views/admin_page/AlertLevelManagement.vue +++ /dev/null @@ -1,1046 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/admin_page/AlertProcessingCenter.vue b/hertz_server_django_ui/src/views/admin_page/AlertProcessingCenter.vue deleted file mode 100644 index a5d8dc3..0000000 --- a/hertz_server_django_ui/src/views/admin_page/AlertProcessingCenter.vue +++ /dev/null @@ -1,1057 +0,0 @@ - - - - - diff --git 
a/hertz_server_django_ui/src/views/admin_page/ArticleManagement.vue b/hertz_server_django_ui/src/views/admin_page/ArticleManagement.vue deleted file mode 100644 index 4c34ffa..0000000 --- a/hertz_server_django_ui/src/views/admin_page/ArticleManagement.vue +++ /dev/null @@ -1,1369 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/admin_page/Dashboard.vue b/hertz_server_django_ui/src/views/admin_page/Dashboard.vue deleted file mode 100644 index 4a5f1a5..0000000 --- a/hertz_server_django_ui/src/views/admin_page/Dashboard.vue +++ /dev/null @@ -1,1536 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/admin_page/DatasetManagement.vue b/hertz_server_django_ui/src/views/admin_page/DatasetManagement.vue deleted file mode 100644 index 1e1c1cb..0000000 --- a/hertz_server_django_ui/src/views/admin_page/DatasetManagement.vue +++ /dev/null @@ -1,1404 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/admin_page/DepartmentManagement.vue b/hertz_server_django_ui/src/views/admin_page/DepartmentManagement.vue deleted file mode 100644 index da2caa9..0000000 --- a/hertz_server_django_ui/src/views/admin_page/DepartmentManagement.vue +++ /dev/null @@ -1,989 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/admin_page/DetectionHistoryManagement.vue b/hertz_server_django_ui/src/views/admin_page/DetectionHistoryManagement.vue deleted file mode 100644 index c8deded..0000000 --- a/hertz_server_django_ui/src/views/admin_page/DetectionHistoryManagement.vue +++ /dev/null @@ -1,1016 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/admin_page/LogManagement.vue b/hertz_server_django_ui/src/views/admin_page/LogManagement.vue deleted file mode 100644 index 060d922..0000000 --- a/hertz_server_django_ui/src/views/admin_page/LogManagement.vue +++ /dev/null @@ -1,1221 +0,0 @@ - - - - - \ No newline at end of file diff --git 
a/hertz_server_django_ui/src/views/admin_page/MenuManagement.vue b/hertz_server_django_ui/src/views/admin_page/MenuManagement.vue deleted file mode 100644 index 14b792c..0000000 --- a/hertz_server_django_ui/src/views/admin_page/MenuManagement.vue +++ /dev/null @@ -1,1199 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/admin_page/ModelManagement.vue b/hertz_server_django_ui/src/views/admin_page/ModelManagement.vue deleted file mode 100644 index cb12311..0000000 --- a/hertz_server_django_ui/src/views/admin_page/ModelManagement.vue +++ /dev/null @@ -1,1832 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/admin_page/NotificationManagement.vue b/hertz_server_django_ui/src/views/admin_page/NotificationManagement.vue deleted file mode 100644 index d5685ef..0000000 --- a/hertz_server_django_ui/src/views/admin_page/NotificationManagement.vue +++ /dev/null @@ -1,1526 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/admin_page/Role.vue b/hertz_server_django_ui/src/views/admin_page/Role.vue deleted file mode 100644 index ac1a972..0000000 --- a/hertz_server_django_ui/src/views/admin_page/Role.vue +++ /dev/null @@ -1,1679 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/admin_page/UserManagement.vue b/hertz_server_django_ui/src/views/admin_page/UserManagement.vue deleted file mode 100644 index c2d0257..0000000 --- a/hertz_server_django_ui/src/views/admin_page/UserManagement.vue +++ /dev/null @@ -1,1571 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/admin_page/YoloTrainManagement.vue b/hertz_server_django_ui/src/views/admin_page/YoloTrainManagement.vue deleted file mode 100644 index da83189..0000000 --- a/hertz_server_django_ui/src/views/admin_page/YoloTrainManagement.vue +++ /dev/null @@ -1,1305 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/admin_page/index.vue 
b/hertz_server_django_ui/src/views/admin_page/index.vue deleted file mode 100644 index 6e76d45..0000000 --- a/hertz_server_django_ui/src/views/admin_page/index.vue +++ /dev/null @@ -1,1689 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/register.vue b/hertz_server_django_ui/src/views/register.vue deleted file mode 100644 index 5d4ec38..0000000 --- a/hertz_server_django_ui/src/views/register.vue +++ /dev/null @@ -1,383 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/user_pages/AiChat.vue b/hertz_server_django_ui/src/views/user_pages/AiChat.vue deleted file mode 100644 index 468da18..0000000 --- a/hertz_server_django_ui/src/views/user_pages/AiChat.vue +++ /dev/null @@ -1,1639 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/AlertCenter.vue b/hertz_server_django_ui/src/views/user_pages/AlertCenter.vue deleted file mode 100644 index ae78e5c..0000000 --- a/hertz_server_django_ui/src/views/user_pages/AlertCenter.vue +++ /dev/null @@ -1,1102 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/user_pages/ArticleCenter.vue b/hertz_server_django_ui/src/views/user_pages/ArticleCenter.vue deleted file mode 100644 index e0e4725..0000000 --- a/hertz_server_django_ui/src/views/user_pages/ArticleCenter.vue +++ /dev/null @@ -1,625 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/ArticleDetail.vue b/hertz_server_django_ui/src/views/user_pages/ArticleDetail.vue deleted file mode 100644 index ae63525..0000000 --- a/hertz_server_django_ui/src/views/user_pages/ArticleDetail.vue +++ /dev/null @@ -1,671 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/DetectionHistory.vue b/hertz_server_django_ui/src/views/user_pages/DetectionHistory.vue deleted file mode 100644 index 6c545ca..0000000 --- a/hertz_server_django_ui/src/views/user_pages/DetectionHistory.vue +++ 
/dev/null @@ -1,2554 +0,0 @@ - - - - - diff --git a/hertz_server_django_ui/src/views/user_pages/Documents.vue b/hertz_server_django_ui/src/views/user_pages/Documents.vue deleted file mode 100644 index 35d9648..0000000 --- a/hertz_server_django_ui/src/views/user_pages/Documents.vue +++ /dev/null @@ -1,251 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/KbCenter.vue b/hertz_server_django_ui/src/views/user_pages/KbCenter.vue deleted file mode 100644 index 5696f5c..0000000 --- a/hertz_server_django_ui/src/views/user_pages/KbCenter.vue +++ /dev/null @@ -1,2499 +0,0 @@ - - - - - - - diff --git a/hertz_server_django_ui/src/views/user_pages/LiveDetection.vue b/hertz_server_django_ui/src/views/user_pages/LiveDetection.vue deleted file mode 100644 index faa652b..0000000 --- a/hertz_server_django_ui/src/views/user_pages/LiveDetection.vue +++ /dev/null @@ -1,2185 +0,0 @@ - - - - - - diff --git a/hertz_server_django_ui/src/views/user_pages/Messages.vue b/hertz_server_django_ui/src/views/user_pages/Messages.vue deleted file mode 100644 index 301f742..0000000 --- a/hertz_server_django_ui/src/views/user_pages/Messages.vue +++ /dev/null @@ -1,318 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/NoticeCenter.vue b/hertz_server_django_ui/src/views/user_pages/NoticeCenter.vue deleted file mode 100644 index 77eabd2..0000000 --- a/hertz_server_django_ui/src/views/user_pages/NoticeCenter.vue +++ /dev/null @@ -1,842 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/Profile.vue b/hertz_server_django_ui/src/views/user_pages/Profile.vue deleted file mode 100644 index f60daaf..0000000 --- a/hertz_server_django_ui/src/views/user_pages/Profile.vue +++ /dev/null @@ -1,519 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/SystemMonitor.vue 
b/hertz_server_django_ui/src/views/user_pages/SystemMonitor.vue deleted file mode 100644 index 254ef96..0000000 --- a/hertz_server_django_ui/src/views/user_pages/SystemMonitor.vue +++ /dev/null @@ -1,835 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/views/user_pages/YoloDetection.vue b/hertz_server_django_ui/src/views/user_pages/YoloDetection.vue deleted file mode 100644 index 4b3b807..0000000 --- a/hertz_server_django_ui/src/views/user_pages/YoloDetection.vue +++ /dev/null @@ -1,3620 +0,0 @@ - - - - - - - diff --git a/hertz_server_django_ui/src/views/user_pages/index.vue b/hertz_server_django_ui/src/views/user_pages/index.vue deleted file mode 100644 index 94e9b32..0000000 --- a/hertz_server_django_ui/src/views/user_pages/index.vue +++ /dev/null @@ -1,7122 +0,0 @@ - - - - - \ No newline at end of file diff --git a/hertz_server_django_ui/src/vite-env.d.ts b/hertz_server_django_ui/src/vite-env.d.ts deleted file mode 100644 index 11f02fe..0000000 --- a/hertz_server_django_ui/src/vite-env.d.ts +++ /dev/null @@ -1 +0,0 @@ -/// diff --git a/hertz_server_django_ui/tsconfig.app.json b/hertz_server_django_ui/tsconfig.app.json deleted file mode 100644 index 49f420b..0000000 --- a/hertz_server_django_ui/tsconfig.app.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "extends": "@vue/tsconfig/tsconfig.dom.json", - "compilerOptions": { - "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo", - "baseUrl": ".", - "paths": { - "@/*": ["src/*"] - }, - - /* Linting - 放宽限制以减少WebStorm警告 */ - "strict": false, - "noUnusedLocals": false, - "noUnusedParameters": false, - "erasableSyntaxOnly": false, - "noFallthroughCasesInSwitch": false, - "noUncheckedSideEffectImports": false, - "exactOptionalPropertyTypes": false, - "noImplicitReturns": false, - "noImplicitOverride": false - }, - "include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"] -} diff --git a/hertz_server_django_ui/tsconfig.json b/hertz_server_django_ui/tsconfig.json deleted file 
mode 100644 index 1ffef60..0000000 --- a/hertz_server_django_ui/tsconfig.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "files": [], - "references": [ - { "path": "./tsconfig.app.json" }, - { "path": "./tsconfig.node.json" } - ] -} diff --git a/hertz_server_django_ui/tsconfig.node.json b/hertz_server_django_ui/tsconfig.node.json deleted file mode 100644 index 9c2440f..0000000 --- a/hertz_server_django_ui/tsconfig.node.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "compilerOptions": { - "tsBuildInfoFile": "./node_modules/.tmp/tsconfig.node.tsbuildinfo", - "target": "ES2023", - "lib": ["ES2023"], - "module": "ESNext", - "skipLibCheck": true, - - /* Bundler mode */ - "moduleResolution": "bundler", - "allowImportingTsExtensions": true, - "verbatimModuleSyntax": true, - "moduleDetection": "force", - "noEmit": true, - - /* Linting - 放宽限制以减少WebStorm警告 */ - "strict": false, - "noUnusedLocals": false, - "noUnusedParameters": false, - "erasableSyntaxOnly": false, - "noFallthroughCasesInSwitch": false, - "noUncheckedSideEffectImports": false - }, - "include": ["vite.config.ts"] -} diff --git a/hertz_server_django_ui/vite.config.ts b/hertz_server_django_ui/vite.config.ts deleted file mode 100644 index 27710db..0000000 --- a/hertz_server_django_ui/vite.config.ts +++ /dev/null @@ -1,333 +0,0 @@ -import { defineConfig, type Plugin, loadEnv } from 'vite' -import vue from '@vitejs/plugin-vue' -import { resolve } from 'path' -import fs from 'fs' -import Components from 'unplugin-vue-components/vite' -import { AntDesignVueResolver } from 'unplugin-vue-components/resolvers' - -// https://vite.dev/config/ -// 生成 public/models/manifest.json,自动列举 .onnx 文件 -function modelsManifestPlugin(): Plugin { - const writeManifest = () => { - try { - const modelsDir = resolve(__dirname, 'public/models') - if (!fs.existsSync(modelsDir)) return - const files = fs - .readdirSync(modelsDir) - .filter((f) => f.toLowerCase().endsWith('.onnx')) - const manifestPath = resolve(modelsDir, 'manifest.json') - 
fs.writeFileSync(manifestPath, JSON.stringify(files, null, 2)) - console.log(`📦 models manifest updated (${files.length}):`, files) - } catch (e) { - console.warn('⚠️ update models manifest failed:', (e as any)?.message) - } - } - - return { - name: 'models-manifest', - apply: 'serve', - configureServer(server) { - writeManifest() - const dir = resolve(__dirname, 'public/models') - try { - if (fs.existsSync(dir)) { - fs.watch(dir, { persistent: true }, (_event, filename) => { - if (!filename) return - if (filename.toLowerCase().endsWith('.onnx')) writeManifest() - }) - } - } catch {} - }, - buildStart() { - writeManifest() - }, - closeBundle() { - writeManifest() - }, - } -} - -export default defineConfig(({ mode }) => { - const env = loadEnv(mode, process.cwd(), '') - const apiBaseUrl = env.VITE_API_BASE_URL || 'http://localhost:3000' - const backendOrigin = apiBaseUrl.replace(/\/+$/, '') - - return { - plugins: [ - vue(), - modelsManifestPlugin(), - Components({ - resolvers: [ - AntDesignVueResolver({ - importStyle: false, // css in js - }), - ], - }), - ], - resolve: { - alias: { - '@': resolve(__dirname, 'src'), - '~': resolve(__dirname, 'src'), - }, - }, - server: { - host: '0.0.0.0', // 新增:允许所有网络接口访问 - port: 3001, // 明确设置为3001端口 - open: true, - cors: true, - proxy: { - // RSS新闻代理转发到百度新闻(需要放在/api之前,优先匹配) - '/api/rss': { - target: 'https://news.baidu.com', - changeOrigin: true, - secure: true, - timeout: 10000, // 设置10秒超时 - rewrite: (path) => { - // 百度新闻RSS格式: /n?cmd=1&class=类别&tn=rss - // 支持多种RSS路径 - if (path.includes('/world')) { - return '/n?cmd=1&class=internet&tn=rss' // 国际新闻 - } else if (path.includes('/tech')) { - return '/n?cmd=1&class=technic&tn=rss' // 科技新闻 - } else if (path.includes('/domestic')) { - return '/n?cmd=1&class=civilnews&tn=rss' // 国内新闻 - } else if (path.includes('/finance')) { - return '/n?cmd=1&class=finance&tn=rss' // 财经新闻 - } - // 默认使用国内新闻 - return '/n?cmd=1&class=civilnews&tn=rss' - }, - configure: (proxy, options) => { - 
proxy.on('proxyReq', (proxyReq, req, res) => { - // 添加必要的请求头,模拟浏览器请求 - proxyReq.setHeader('User-Agent', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36') - proxyReq.setHeader('Accept', 'application/xml, text/xml, */*') - proxyReq.setHeader('Accept-Language', 'zh-CN,zh;q=0.9,en;q=0.8') - proxyReq.setHeader('Referer', 'https://news.baidu.com/') - proxyReq.setHeader('Host', 'news.baidu.com') - // 移除Origin,避免CORS问题 - proxyReq.removeHeader('Origin') - if (process.env.NODE_ENV === 'development') { - console.log(`📰 RSS代理请求: ${req.method} ${req.url} -> ${proxyReq.path}`) - } - }) - - proxy.on('proxyRes', (proxyRes, req, res) => { - // 添加CORS头部 - res.setHeader('Access-Control-Allow-Origin', '*') - res.setHeader('Access-Control-Allow-Methods', 'GET, OPTIONS') - res.setHeader('Access-Control-Allow-Headers', 'Content-Type') - res.setHeader('Access-Control-Expose-Headers', 'Content-Type') - - // 确保Content-Type正确 - if (proxyRes.headers['content-type']) { - res.setHeader('Content-Type', proxyRes.headers['content-type']) - } else { - res.setHeader('Content-Type', 'application/xml; charset=utf-8') - } - - if (process.env.NODE_ENV === 'development') { - console.log(`✅ RSS响应: ${proxyRes.statusCode} ${req.url}`) - } - }) - - proxy.on('error', (err, req, res) => { - console.error('❌ RSS代理错误:', err.message) - }) - }, - }, - // 翻译API代理转发到腾讯翻译(需要放在/api之前,优先匹配) - '/api/translate': { - target: 'https://fanyi.qq.com', - changeOrigin: true, - secure: true, - timeout: 10000, // 设置10秒超时 - rewrite: (path) => { - // 腾讯翻译接口路径是 /api/translate,需要保留所有查询参数 - const pathWithoutPrefix = path.replace(/^\/api\/translate/, '/api/translate') - if (process.env.NODE_ENV === 'development') { - console.log('翻译代理路径重写:', path, '->', pathWithoutPrefix) - } - return pathWithoutPrefix - }, - configure: (proxy, options) => { - proxy.on('proxyReq', (proxyReq, req, res) => { - // 添加必要的请求头,模拟浏览器请求 - proxyReq.setHeader('User-Agent', 'Mozilla/5.0 (Windows NT 
10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36') - proxyReq.setHeader('Accept', 'application/json, text/plain, */*') - proxyReq.setHeader('Accept-Language', 'zh-CN,zh;q=0.9,en;q=0.8') - proxyReq.setHeader('Referer', 'https://fanyi.qq.com/') - proxyReq.setHeader('Content-Type', 'application/json; charset=UTF-8') - // 移除Origin,避免CORS问题 - if (proxyReq.getHeader('Origin')) { - proxyReq.removeHeader('Origin') - } - if (process.env.NODE_ENV === 'development') { - console.log(`🌐 翻译代理请求: ${req.method} ${req.url} -> ${proxyReq.path}`) - console.log(`🌐 代理目标: https://fanyi.qq.com${proxyReq.path}`) - } - }) - - proxy.on('proxyRes', (proxyRes, req, res) => { - // 添加CORS头部 - res.setHeader('Access-Control-Allow-Origin', '*') - res.setHeader('Access-Control-Allow-Methods', 'GET, POST, OPTIONS') - res.setHeader('Access-Control-Allow-Headers', 'Content-Type') - res.setHeader('Access-Control-Expose-Headers', 'Content-Type') - - // 确保Content-Type正确 - if (proxyRes.headers['content-type']) { - res.setHeader('Content-Type', proxyRes.headers['content-type']) - } else { - res.setHeader('Content-Type', 'application/json; charset=utf-8') - } - - if (process.env.NODE_ENV === 'development') { - console.log(`✅ 翻译响应: ${proxyRes.statusCode} ${req.url}`) - // 如果是错误状态码,记录详细信息 - if (proxyRes.statusCode >= 400) { - console.error(`❌ 翻译API错误: ${proxyRes.statusCode} ${req.url}`) - } - } - }) - - proxy.on('error', (err, req, res) => { - console.error('❌ 翻译代理错误:', err.message) - }) - }, - }, - // 天气API代理转发到中国气象局(需要放在/api之前,优先匹配) - '/api/weather': { - target: 'https://weather.cma.cn', - changeOrigin: true, - secure: true, - rewrite: (path) => path.replace(/^\/api\/weather/, '/api/weather'), - configure: (proxy, options) => { - proxy.on('proxyReq', (proxyReq, req, res) => { - // 添加必要的请求头,模拟浏览器请求 - proxyReq.setHeader('User-Agent', 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36') - 
proxyReq.setHeader('Accept', 'application/json, text/plain, */*') - proxyReq.setHeader('Accept-Language', 'zh-CN,zh;q=0.9,en;q=0.8') - proxyReq.setHeader('Referer', 'https://weather.cma.cn/') - proxyReq.setHeader('Origin', 'https://weather.cma.cn') - proxyReq.setHeader('X-Proxy-By', 'Vite-Dev-Server') - if (process.env.NODE_ENV === 'development') { - console.log(`🌤️ 天气API代理: ${req.method} ${req.url}`) - } - }) - - proxy.on('proxyRes', (proxyRes, req, res) => { - // 添加CORS头部 - res.setHeader('Access-Control-Allow-Origin', '*') - res.setHeader('Access-Control-Allow-Methods', 'GET, OPTIONS') - res.setHeader('Access-Control-Allow-Headers', 'Content-Type') - if (process.env.NODE_ENV === 'development') { - console.log(`✅ 天气API响应: ${proxyRes.statusCode} ${req.url}`) - } - }) - - proxy.on('error', (err, req, res) => { - console.error('❌ 天气API代理错误:', err.message) - }) - }, - }, - // API代理转发到后端服务器 - '/api': { - target: backendOrigin, - changeOrigin: true, - secure: false, - rewrite: (path) => path.replace(/^\/api/, '/api'), - // 优化Network面板显示 - // 保持原始头部信息 - preserveHeaderKeyCase: true, - // 添加CORS头部,改善Network面板显示 - headers: { - 'Access-Control-Allow-Origin': '*', - 'Access-Control-Allow-Methods': 'GET,PUT,POST,DELETE,PATCH,OPTIONS', - 'Access-Control-Allow-Headers': 'Content-Type, Authorization, Content-Length, X-Requested-With', - }, - configure: (proxy, options) => { - // 简化代理日志 - proxy.on('proxyReq', (proxyReq, req, res) => { - // 添加标识头部,帮助Network面板识别 - proxyReq.setHeader('X-Proxy-By', 'Vite-Dev-Server') - if (process.env.NODE_ENV === 'development') { - console.log(`🔄 代理: ${req.method} ${req.url}`) - } - }) - - proxy.on('proxyRes', (proxyRes, req, res) => { - // 添加响应头部,改善Network面板显示 - res.setHeader('X-Proxy-By', 'Vite-Dev-Server') - res.setHeader('Access-Control-Allow-Origin', '*') - if (process.env.NODE_ENV === 'development') { - console.log(`✅ 响应: ${proxyRes.statusCode} ${req.url}`) - } - }) - - proxy.on('error', (err, req, res) => { - console.error('❌ 代理错误:', 
err.message) - }) - }, - }, - // 媒体文件代理转发到后端服务器 - '/media': { - target: backendOrigin, - changeOrigin: true, - secure: false, - rewrite: (path) => path.replace(/^\/media/, '/media'), - // 优化Network面板显示 - // 保持原始头部信息 - preserveHeaderKeyCase: true, - // 添加CORS头部,改善Network面板显示 - headers: { - 'Access-Control-Allow-Origin': '*', - 'Access-Control-Allow-Methods': 'GET,PUT,POST,DELETE,PATCH,OPTIONS', - 'Access-Control-Allow-Headers': 'Content-Type, Authorization, Content-Length, X-Requested-With', - }, - configure: (proxy, options) => { - // 简化代理日志 - proxy.on('proxyReq', (proxyReq, req, res) => { - // 添加标识头部,帮助Network面板识别 - proxyReq.setHeader('X-Proxy-By', 'Vite-Dev-Server') - if (process.env.NODE_ENV === 'development') { - console.log(`🔄 媒体代理: ${req.method} ${req.url}`) - } - }) - - proxy.on('proxyRes', (proxyRes, req, res) => { - // 添加响应头部,改善Network面板显示 - res.setHeader('X-Proxy-By', 'Vite-Dev-Server') - res.setHeader('Access-Control-Allow-Origin', '*') - if (process.env.NODE_ENV === 'development') { - console.log(`✅ 媒体响应: ${proxyRes.statusCode} ${req.url}`) - } - }) - - proxy.on('error', (err, req, res) => { - console.error('❌ 媒体代理错误:', err.message) - }) - }, - }, - }, - }, - define: { - // 环境变量定义,确保在没有.env文件时也能正常工作 - __VITE_API_BASE_URL__: JSON.stringify(`${backendOrigin}/api`), - __VITE_APP_TITLE__: JSON.stringify('Hertz Admin'), - __VITE_APP_VERSION__: JSON.stringify('1.0.0'), - }, - build: { - sourcemap: true, - rollupOptions: { - output: { - manualChunks: { - vue: ['vue', 'vue-router', 'pinia'], - antd: ['ant-design-vue'], - utils: ['axios', 'echarts'], - }, - }, - }, - }, - } -}) diff --git a/hertz_server_django_ui/修改操作指南.md b/hertz_server_django_ui/修改操作指南.md deleted file mode 100644 index 06ecc5f..0000000 --- a/hertz_server_django_ui/修改操作指南.md +++ /dev/null @@ -1,422 +0,0 @@ -# 前端样式 / 布局 / UI 修改操作指南 - -> 面向二次开发前端,告诉你:**改样式 / 改布局 / 改 UI 具体要动哪些文件、怎么改**。 -> -> 项目技术栈:Vite + Vue 3 + TypeScript + Ant Design Vue + Pinia + Vue Router + SCSS - ---- - -## 🧭 速查表 - -- 
**改全局颜色 / 按钮 / 弹窗风格**:看第 1 章「整体样式体系总览」(`src/styles/index.scss` + `src/styles/variables.scss`) -- **改管理端整体布局(侧边栏、头部、内容区排版)**:看 2.1「管理端整体布局」(`src/views/admin_page/index.vue`) -- **改用户端整体布局(顶部导航 + 内容容器)」**:看 2.2「用户端整体布局」(`src/views/user_pages/index.vue`) -- **改 YOLO 检测排版 / 三种布局 / 卡片样式**:看第 3 章「YOLO 检测页面修改指南」(`src/views/user_pages/YoloDetection.vue`) -- **改 AI 助手聊天布局**:看第 4 章「修改 AI 助手页面」(`src/views/user_pages/AiChat.vue`) - ---- - -## 1. 整体样式体系总览 - -### 1.1 全局样式入口 - -- 入口文件:`src/styles/index.scss` -- 在 `src/main.ts` 中全局引入: - - `import './styles/index.scss'` -- 主要职责: - - 重置 margin / padding / box-sizing - - 全局字体、`html, body, #app` 基础样式 - - 自定义 `.btn` / `.card` 等通用类 - - 全局 Ant Design Vue 主题风格覆盖(如 `.ant-modal`, `.ant-btn` 等) - -**如果你要改全局的按钮、弹窗、表单、输入框等基础风格:** - -1. 打开 `src/styles/index.scss` -2. 找对应的选择器: - - 按钮:`.ant-btn` 下的几种状态(`&.ant-btn-default` / `&.ant-btn-primary` / `&.ant-btn-dangerous` 等) - - 弹窗:`.ant-modal` 内的 `.ant-modal-content` / `.ant-modal-header` / `.ant-modal-footer` - - 输入/选择等:`.ant-input`, `.ant-select-selector`, `.ant-input-number`, `.ant-picker` 等 -3. 直接在这里调整颜色、圆角、阴影、间距。 -4. 样式会作用于所有页面,无需在每个 `.vue` 里重复写。 - -> 建议:全局 Design System 统一改在这里,不要在业务页面里到处改 AntD 默认样式。 - -### 1.2 变量和混合(主题基础) - -- 文件:`src/styles/variables.scss` -- 主要内容: - - 颜色:`$primary-color`、`$success-color`、`$gray-xxx` 等 - - 间距:`$spacing-1 ~ $spacing-20` - - 圆角:`$radius-md`、`$radius-lg` 等 - - 阴影:`$shadow-sm` / `$shadow-md` / `$shadow-lg` - - 常用 mixin:`@mixin card-style`、`@mixin button-style` 等 - -**改全局配色 / 圆角 / 阴影的操作方式:** - -1. 打开 `src/styles/variables.scss` -2. 修改对应变量: - - 主色:`$primary-color` / `$primary-light` / `$primary-dark` - - 背景:`$bg-primary` / `$bg-secondary` - - 阴影:`$shadow-md` / `$shadow-lg` -3. 
不需要修改业务页面,使用这些变量的地方会统一生效。 - -> 如果要在页面里复用统一卡片/按钮样式,可以直接: -> -> ```scss -> .my-card { -> @include card-style; -> } -> -> .my-primary-button { -> @include button-style($primary-color, #fff); -> } -> ``` - -### 1.3 主题 store 与 CSS 变量 - -- 文件:`src/stores/hertz_theme.ts` -- 作用: - - 定义 `ThemeConfig`(导航栏背景、页面背景、卡片背景、主色、文字颜色) - - 使用 `document.documentElement.style.setProperty` 写入 CSS 变量: - - `--theme-header-bg`, `--theme-page-bg`, `--theme-card-bg`, `--theme-primary`, `--theme-text-primary` 等 -- 使用方式: - - 在页面/组件的 SCSS 中,通过 `var(--theme-primary)` 等变量引用主题色: - - 示例:`color: var(--theme-text-primary, #1e293b);` - -**修改主题默认值**: - -1. 打开 `src/stores/hertz_theme.ts` -2. 修改 `defaultTheme` 对象里的颜色值即可: - - 如:`primaryColor: '#FF4D4F'` 改成你的品牌色 -3. 调用 `themeStore.loadTheme()` 时会自动应用到全局。 - -**页面内如何用这些主题变量?** - -- 在 SCSS 中使用: - -```scss -.some-block { - background: var(--theme-card-bg, #fff); - color: var(--theme-text-primary, #1e293b); - border-color: var(--theme-card-border, #e5e7eb); -} -``` - ---- - -## 2. 布局结构:管理端 / 用户端 - -### 2.1 管理端整体布局 - -- 入口布局:`src/views/admin_page/index.vue` -- 结构: - - 外层 `.admin-layout` - - 使用 `a-layout` + `a-layout-sider` + `a-layout-header` + `a-layout-content` + `a-layout-footer` - - 侧边菜单:`a-layout-sider` 内的 `a-menu`,使用 `admin_menu.ts` 生成菜单项 - -**修改管理端整体布局方式(比如侧边栏宽度、顶部高度):** - -1. 打开 `src/views/admin_page/index.vue` -2. 找到模板部分: - - 侧边栏:`` - - 顶部:`` - - 内容:`` -3. 在同文件底部的 `""" # Hide main menu style - - # Main title of streamlit application - main_title_cfg = """

Ultralytics YOLO Streamlit Application

""" - - # Subtitle of streamlit application - sub_title_cfg = """
Experience real-time object detection on your webcam, videos, and images - with the power of Ultralytics YOLO! 🚀
""" - - # Set html page configuration and append custom HTML - self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide") - self.st.markdown(menu_style_cfg, unsafe_allow_html=True) - self.st.markdown(main_title_cfg, unsafe_allow_html=True) - self.st.markdown(sub_title_cfg, unsafe_allow_html=True) - - def sidebar(self) -> None: - """Configure the Streamlit sidebar for model and inference settings.""" - with self.st.sidebar: # Add Ultralytics LOGO - logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg" - self.st.image(logo, width=250) - - self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu - self.source = self.st.sidebar.selectbox( - "Source", - ("webcam", "video", "image"), - ) # Add source selection dropdown - if self.source in ["webcam", "video"]: - self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes" # Enable object tracking - self.conf = float( - self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01) - ) # Slider for confidence - self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold - - if self.source != "image": # Only create columns for video/webcam - col1, col2 = self.st.columns(2) # Create two columns for displaying frames - self.org_frame = col1.empty() # Container for original frame - self.ann_frame = col2.empty() # Container for annotated frame - - def source_upload(self) -> None: - """Handle video file uploads through the Streamlit interface.""" - from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS # scope import - - self.vid_file_name = "" - if self.source == "video": - vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS) - if vid_file is not None: - g = io.BytesIO(vid_file.read()) # BytesIO Object - with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes - out.write(g.read()) # Read bytes into 
file - self.vid_file_name = "ultralytics.mp4" - elif self.source == "webcam": - self.vid_file_name = 0 # Use webcam index 0 - elif self.source == "image": - import tempfile # scope import - - imgfiles = self.st.sidebar.file_uploader("Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True) - if imgfiles: - for imgfile in imgfiles: # Save each uploaded image to a temporary file - with tempfile.NamedTemporaryFile(delete=False, suffix=f".{imgfile.name.split('.')[-1]}") as tf: - tf.write(imgfile.read()) - self.img_file_names.append({"path": tf.name, "name": imgfile.name}) - - def configure(self) -> None: - """Configure the model and load selected classes for inference.""" - # Add dropdown menu for model selection - M_ORD, T_ORD = ["yolo11n", "yolo11s", "yolo11m", "yolo11l", "yolo11x"], ["", "-seg", "-pose", "-obb", "-cls"] - available_models = sorted( - [ - x.replace("yolo", "YOLO") - for x in GITHUB_ASSETS_STEMS - if any(x.startswith(b) for b in M_ORD) and "grayscale" not in x - ], - key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or "")), - ) - if self.model_path: # If user provided the custom model, insert model without suffix as *.pt is added later - available_models.insert(0, self.model_path.split(".pt", 1)[0]) - selected_model = self.st.sidebar.selectbox("Model", available_models) - - with self.st.spinner("Model is downloading..."): - self.model = YOLO(f"{selected_model.lower()}.pt") # Load the YOLO model - class_names = list(self.model.names.values()) # Convert dictionary to list of class names - self.st.success("Model loaded successfully!") - - # Multiselect box with class names and get indices of selected classes - selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3]) - self.selected_ind = [class_names.index(option) for option in selected_classes] - - if not isinstance(self.selected_ind, list): # Ensure selected_options is a list - self.selected_ind = list(self.selected_ind) - - def 
image_inference(self) -> None: - """Perform inference on uploaded images.""" - for idx, img_info in enumerate(self.img_file_names): - img_path = img_info["path"] - image = cv2.imread(img_path) # Load and display the original image - if image is not None: - self.st.markdown(f"#### Processed: {img_info['name']}") - col1, col2 = self.st.columns(2) - with col1: - self.st.image(image, channels="BGR", caption="Original Image") - results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind) - annotated_image = results[0].plot() - with col2: - self.st.image(annotated_image, channels="BGR", caption="Predicted Image") - try: # Clean up temporary file - os.unlink(img_path) - except FileNotFoundError: - pass # File doesn't exist, ignore - else: - self.st.error("Could not load the uploaded image.") - - def inference(self) -> None: - """Perform real-time object detection inference on video or webcam feed.""" - self.web_ui() # Initialize the web interface - self.sidebar() # Create the sidebar - self.source_upload() # Upload the video source - self.configure() # Configure the app - - if self.st.sidebar.button("Start"): - if self.source == "image": - if self.img_file_names: - self.image_inference() - else: - self.st.info("Please upload an image file to perform inference.") - return - - stop_button = self.st.sidebar.button("Stop") # Button to stop the inference - cap = cv2.VideoCapture(self.vid_file_name) # Capture the video - if not cap.isOpened(): - self.st.error("Could not open webcam or video source.") - return - - while cap.isOpened(): - success, frame = cap.read() - if not success: - self.st.warning("Failed to read frame from webcam. 
Please verify the webcam is connected properly.") - break - - # Process frame with model - if self.enable_trk: - results = self.model.track( - frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True - ) - else: - results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind) - - annotated_frame = results[0].plot() # Add annotations on frame - - if stop_button: - cap.release() # Release the capture - self.st.stop() # Stop streamlit app - - self.org_frame.image(frame, channels="BGR", caption="Original Frame") # Display original frame - self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame") # Display processed - - cap.release() # Release the capture - cv2.destroyAllWindows() # Destroy all OpenCV windows - - -if __name__ == "__main__": - import sys # Import the sys module for accessing command-line arguments - - # Check if a model name is provided as a command-line argument - args = len(sys.argv) - model = sys.argv[1] if args > 1 else None # Assign first argument as the model name if provided - # Create an instance of the Inference class and run inference - Inference(model=model).inference() diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/templates/similarity-search.html b/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/templates/similarity-search.html deleted file mode 100644 index 6a24179..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/templates/similarity-search.html +++ /dev/null @@ -1,167 +0,0 @@ - - - - - - - - - Semantic Image Search - - - - - -
- Ultralytics Logo -
-

Semantic Image Search with AI

- - -
- - - {% if results %} -
- - - -
- {% endif %} -
- - -
- {% for img in results %} -
- Result Image -
- {% endfor %} -
- - diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/trackzone.py b/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/trackzone.py deleted file mode 100644 index b437769..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/trackzone.py +++ /dev/null @@ -1,91 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from typing import Any - -import cv2 -import numpy as np - -from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults -from ultralytics.utils.plotting import colors - - -class TrackZone(BaseSolution): - """ - A class to manage region-based object tracking in a video stream. - - This class extends the BaseSolution class and provides functionality for tracking objects within a specific region - defined by a polygonal area. Objects outside the region are excluded from tracking. - - Attributes: - region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points. - line_width (int): Width of the lines used for drawing bounding boxes and region boundaries. - names (List[str]): List of class names that the model can detect. - boxes (List[np.ndarray]): Bounding boxes of tracked objects. - track_ids (List[int]): Unique identifiers for each tracked object. - clss (List[int]): Class indices of tracked objects. - - Methods: - process: Process each frame of the video, applying region-based tracking. - extract_tracks: Extract tracking information from the input frame. - display_output: Display the processed output. - - Examples: - >>> tracker = TrackZone() - >>> frame = cv2.imread("frame.jpg") - >>> results = tracker.process(frame) - >>> cv2.imshow("Tracked Frame", results.plot_im) - """ - - def __init__(self, **kwargs: Any) -> None: - """ - Initialize the TrackZone class for tracking objects within a defined region in video streams. - - Args: - **kwargs (Any): Additional keyword arguments passed to the parent class. 
- """ - super().__init__(**kwargs) - default_region = [(75, 75), (565, 75), (565, 285), (75, 285)] - self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32)) - self.mask = None - - def process(self, im0: np.ndarray) -> SolutionResults: - """ - Process the input frame to track objects within a defined region. - - This method initializes the annotator, creates a mask for the specified region, extracts tracks - only from the masked area, and updates tracking information. Objects outside the region are ignored. - - Args: - im0 (np.ndarray): The input image or frame to be processed. - - Returns: - (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the - total number of tracked objects within the defined region. - - Examples: - >>> tracker = TrackZone() - >>> frame = cv2.imread("path/to/image.jpg") - >>> results = tracker.process(frame) - """ - annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator - - if self.mask is None: # Create a mask for the region - self.mask = np.zeros_like(im0[:, :, 0]) - cv2.fillPoly(self.mask, [self.region], 255) - masked_frame = cv2.bitwise_and(im0, im0, mask=self.mask) - self.extract_tracks(masked_frame) - - # Draw the region boundary - cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2) - - # Iterate over boxes, track ids, classes indexes list and draw bounding boxes - for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs): - annotator.box_label( - box, label=self.adjust_box_label(cls, conf, track_id=track_id), color=colors(track_id, True) - ) - - plot_im = annotator.result() - self.display_output(plot_im) # Display output with base class function - - # Return a SolutionResults - return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids)) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/vision_eye.py 
b/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/vision_eye.py deleted file mode 100644 index 27ecfc7..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/solutions/vision_eye.py +++ /dev/null @@ -1,70 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from typing import Any - -from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults -from ultralytics.utils.plotting import colors - - -class VisionEye(BaseSolution): - """ - A class to manage object detection and vision mapping in images or video streams. - - This class extends the BaseSolution class and provides functionality for detecting objects, - mapping vision points, and annotating results with bounding boxes and labels. - - Attributes: - vision_point (Tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks. - - Methods: - process: Process the input image to detect objects, annotate them, and apply vision mapping. - - Examples: - >>> vision_eye = VisionEye() - >>> frame = cv2.imread("frame.jpg") - >>> results = vision_eye.process(frame) - >>> print(f"Total detected instances: {results.total_tracks}") - """ - - def __init__(self, **kwargs: Any) -> None: - """ - Initialize the VisionEye class for detecting objects and applying vision mapping. - - Args: - **kwargs (Any): Keyword arguments passed to the parent class and for configuring vision_point. - """ - super().__init__(**kwargs) - # Set the vision point where the system will view objects and draw tracks - self.vision_point = self.CFG["vision_point"] - - def process(self, im0) -> SolutionResults: - """ - Perform object detection, vision mapping, and annotation on the input image. - - Args: - im0 (np.ndarray): The input image for detection and annotation. - - Returns: - (SolutionResults): Object containing the annotated image and tracking statistics. 
- - plot_im: Annotated output image with bounding boxes and vision mapping - - total_tracks: Number of tracked objects in the frame - - Examples: - >>> vision_eye = VisionEye() - >>> frame = cv2.imread("image.jpg") - >>> results = vision_eye.process(frame) - >>> print(f"Detected {results.total_tracks} objects") - """ - self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks) - annotator = SolutionAnnotator(im0, self.line_width) - - for cls, t_id, box, conf in zip(self.clss, self.track_ids, self.boxes, self.confs): - # Annotate the image with bounding boxes, labels, and vision mapping - annotator.box_label(box, label=self.adjust_box_label(cls, conf, t_id), color=colors(int(t_id), True)) - annotator.visioneye(box, self.vision_point) - - plot_im = annotator.result() - self.display_output(plot_im) # Display the annotated output using the base class function - - # Return a SolutionResults object with the annotated image and tracking statistics - return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids)) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/README.md b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/README.md deleted file mode 100644 index de6acbd..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/README.md +++ /dev/null @@ -1,295 +0,0 @@ -Ultralytics logo - -# Multi-Object Tracking with Ultralytics YOLO - -Ultralytics YOLO trackers visualization - -[Object tracking](https://www.ultralytics.com/glossary/object-tracking), a key aspect of [video analytics](https://en.wikipedia.org/wiki/Video_content_analysis), involves identifying the location and class of objects within video frames and assigning a unique ID to each detected object as it moves. This capability enables a wide range of applications, from surveillance and security systems to [real-time](https://www.ultralytics.com/glossary/real-time-inference) sports analysis and autonomous vehicle navigation. 
Learn more about tracking on our [tracking documentation page](https://docs.ultralytics.com/modes/track/). - -## 🎯 Why Choose Ultralytics YOLO for Object Tracking? - -Ultralytics YOLO trackers provide output consistent with standard [object detection](https://docs.ultralytics.com/tasks/detect/) but add persistent object IDs. This simplifies the process of tracking objects in video streams and performing subsequent analyses. Here’s why Ultralytics YOLO is an excellent choice for your object tracking needs: - -- **Efficiency:** Process video streams in real-time without sacrificing accuracy. -- **Flexibility:** Supports multiple robust tracking algorithms and configurations. -- **Ease of Use:** Offers straightforward [Python API](https://docs.ultralytics.com/usage/python/) and [CLI](https://docs.ultralytics.com/usage/cli/) options for rapid integration and deployment. -- **Customizability:** Easily integrates with [custom-trained YOLO models](https://docs.ultralytics.com/modes/train/), enabling deployment in specialized, domain-specific applications. - -**Watch:** Object Detection and Tracking with Ultralytics YOLOv8. - -[![Watch the video](https://user-images.githubusercontent.com/26833433/244171528-66a4a68d-cb85-466a-984a-34301616b7a3.png)](https://www.youtube.com/watch?v=hHyHmOtmEgs) - -## ✨ Features at a Glance - -Ultralytics YOLO extends its powerful object detection features to deliver robust and versatile object tracking: - -- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos. -- **Multiple Tracker Support:** Choose from a selection of established tracking algorithms. -- **Customizable Tracker Configurations:** Adapt the tracking algorithm to specific requirements by adjusting various parameters. - -## 🛠️ Available Trackers - -Ultralytics YOLO supports the following tracking algorithms. 
Enable them by passing the relevant YAML configuration file, such as `tracker=tracker_type.yaml`: - -- **BoT-SORT:** Use [`botsort.yaml`](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/trackers/botsort.yaml) to enable this tracker. Based on the [BoT-SORT paper](https://arxiv.org/abs/2206.14651) and its official [code implementation](https://github.com/NirAharon/BoT-SORT). -- **ByteTrack:** Use [`bytetrack.yaml`](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/trackers/bytetrack.yaml) to enable this tracker. Based on the [ByteTrack paper](https://arxiv.org/abs/2110.06864) and its official [code implementation](https://github.com/FoundationVision/ByteTrack). - -The default tracker is **BoT-SORT**. - -## ⚙️ Usage - -To run the tracker on video streams, use a trained Detect, Segment, or Pose model like [Ultralytics YOLO11n](https://docs.ultralytics.com/models/yolo11/), YOLO11n-seg, or YOLO11n-pose. - -```python -# Python -from ultralytics import YOLO - -# Load an official or custom model -model = YOLO("yolo11n.pt") # Load an official Detect model -# model = YOLO("yolo11n-seg.pt") # Load an official Segment model -# model = YOLO("yolo11n-pose.pt") # Load an official Pose model -# model = YOLO("path/to/best.pt") # Load a custom trained model - -# Perform tracking with the model -results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True) # Tracking with default tracker -# results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml") # Tracking with ByteTrack tracker -``` - -```bash -# CLI -# Perform tracking with various models using the command line interface -yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model -# yolo track model=yolo11n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model -# yolo track model=yolo11n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model -# yolo track model=path/to/best.pt 
source="https://youtu.be/LNwODJXcvt4" # Custom trained model - -# Track using ByteTrack tracker -# yolo track model=path/to/best.pt tracker="bytetrack.yaml" -``` - -As shown above, tracking is available for all [Detect](https://docs.ultralytics.com/tasks/detect/), [Segment](https://docs.ultralytics.com/tasks/segment/), and [Pose](https://docs.ultralytics.com/tasks/pose/) models when run on videos or streaming sources. - -## 🔧 Configuration - -### Tracking Arguments - -Tracking configuration shares properties with the Predict mode, such as `conf` (confidence threshold), `iou` ([Intersection over Union](https://www.ultralytics.com/glossary/intersection-over-union-iou) threshold), and `show` (display results). For additional configurations, refer to the [Predict mode documentation](https://docs.ultralytics.com/modes/predict/). - -```python -# Python -from ultralytics import YOLO - -# Configure the tracking parameters and run the tracker -model = YOLO("yolo11n.pt") -results = model.track(source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True) -``` - -```bash -# CLI -# Configure tracking parameters and run the tracker using the command line interface -yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3 iou=0.5 show -``` - -### Tracker Selection - -Ultralytics allows you to use a modified tracker configuration file. Create a copy of a tracker config file (e.g., `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and adjust any configurations (except `tracker_type`) according to your needs. 
- -```python -# Python -from ultralytics import YOLO - -# Load the model and run the tracker with a custom configuration file -model = YOLO("yolo11n.pt") -results = model.track(source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml") -``` - -```bash -# CLI -# Load the model and run the tracker with a custom configuration file using the command line interface -yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml' -``` - -For a comprehensive list of tracking arguments, consult the [Tracking Configuration files](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) in the repository. - -## 🐍 Python Examples - -### Persisting Tracks Loop - -This Python script uses [OpenCV (`cv2`)](https://opencv.org/) and Ultralytics YOLO11 to perform object tracking on video frames. Ensure you have installed the necessary packages (`opencv-python` and `ultralytics`). The [`persist=True`](https://docs.ultralytics.com/modes/predict/#tracking) argument indicates that the current frame is the next in a sequence, allowing the tracker to maintain track continuity from the previous frame. 
- -```python -# Python -import cv2 - -from ultralytics import YOLO - -# Load the YOLO11 model -model = YOLO("yolo11n.pt") - -# Open the video file -video_path = "path/to/video.mp4" -cap = cv2.VideoCapture(video_path) - -# Loop through the video frames -while cap.isOpened(): - # Read a frame from the video - success, frame = cap.read() - - if success: - # Run YOLO11 tracking on the frame, persisting tracks between frames - results = model.track(frame, persist=True) - - # Visualize the results on the frame - annotated_frame = results[0].plot() - - # Display the annotated frame - cv2.imshow("YOLO11 Tracking", annotated_frame) - - # Break the loop if 'q' is pressed - if cv2.waitKey(1) & 0xFF == ord("q"): - break - else: - # Break the loop if the end of the video is reached - break - -# Release the video capture object and close the display window -cap.release() -cv2.destroyAllWindows() -``` - -Note the use of `model.track(frame)` instead of `model(frame)`, which specifically enables object tracking. This script processes each video frame, visualizes the tracking results, and displays them. Press 'q' to exit the loop. - -### Plotting Tracks Over Time - -Visualizing object tracks across consecutive frames offers valuable insights into movement patterns within a video. Ultralytics YOLO11 makes plotting these tracks efficient. - -The following example demonstrates how to use YOLO11's tracking capabilities to plot the movement of detected objects. The script opens a video, reads it frame by frame, and uses the YOLO model built on [PyTorch](https://pytorch.org/) to identify and track objects. By storing the center points of the detected [bounding boxes](https://www.ultralytics.com/glossary/bounding-box) and connecting them, we can draw lines representing the paths of tracked objects using [NumPy](https://numpy.org/) for numerical operations. 
- -```python -# Python -from collections import defaultdict - -import cv2 -import numpy as np - -from ultralytics import YOLO - -# Load the YOLO11 model -model = YOLO("yolo11n.pt") - -# Open the video file -video_path = "path/to/video.mp4" -cap = cv2.VideoCapture(video_path) - -# Store the track history -track_history = defaultdict(lambda: []) - -# Loop through the video frames -while cap.isOpened(): - # Read a frame from the video - success, frame = cap.read() - - if success: - # Run YOLO11 tracking on the frame, persisting tracks between frames - result = model.track(frame, persist=True)[0] - - # Get the boxes and track IDs - if result.boxes and result.boxes.is_track: - boxes = result.boxes.xywh.cpu() - track_ids = result.boxes.id.int().cpu().tolist() - - # Visualize the result on the frame - frame = result.plot() - - # Plot the tracks - for box, track_id in zip(boxes, track_ids): - x, y, w, h = box - track = track_history[track_id] - track.append((float(x), float(y))) # x, y center point - if len(track) > 30: # retain 30 tracks for 30 frames - track.pop(0) - - # Draw the tracking lines - points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) - cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=10) - - # Display the annotated frame - cv2.imshow("YOLO11 Tracking", frame) - - # Break the loop if 'q' is pressed - if cv2.waitKey(1) & 0xFF == ord("q"): - break - else: - # Break the loop if the end of the video is reached - break - -# Release the video capture object and close the display window -cap.release() -cv2.destroyAllWindows() -``` - -### Multithreaded Tracking - -Multithreaded tracking allows running object tracking on multiple video streams simultaneously, which is highly beneficial for systems handling inputs from several cameras, improving efficiency through concurrent processing. 
- -This Python script utilizes Python's [`threading`](https://docs.python.org/3/library/threading.html) module for concurrent tracker execution. Each thread manages tracking for a single video file. - -The `run_tracker_in_thread` function accepts parameters like the video file path, model, and a unique window index. It contains the main tracking loop, reading frames, running the tracker, and displaying results in a dedicated window. - -This example uses two models, `yolo11n.pt` and `yolo11n-seg.pt`, tracking objects in `video_file1` and `video_file2`, respectively. - -Setting `daemon=True` in `threading.Thread` ensures threads exit when the main program finishes. Threads are started with `start()` and the main thread waits for their completion using `join()`. - -Finally, `cv2.destroyAllWindows()` closes all OpenCV windows after the threads finish. - -```python -# Python -import threading - -import cv2 - -from ultralytics import YOLO - -# Define model names and video sources -MODEL_NAMES = ["yolo11n.pt", "yolo11n-seg.pt"] -SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam - - -def run_tracker_in_thread(model_name, filename): - """ - Run YOLO tracker in its own thread for concurrent processing. - - Args: - model_name (str): The YOLO11 model object. - filename (str): The path to the video file or the identifier for the webcam/external camera source. 
- """ - model = YOLO(model_name) - results = model.track(filename, save=True, stream=True) - for r in results: - pass - - -# Create and start tracker threads using a for loop -tracker_threads = [] -for video_file, model_name in zip(SOURCES, MODEL_NAMES): - thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True) - tracker_threads.append(thread) - thread.start() - -# Wait for all tracker threads to finish -for thread in tracker_threads: - thread.join() - -# Clean up and close windows -cv2.destroyAllWindows() -``` - -This setup can be easily scaled to handle more video streams by creating additional threads following the same pattern. Explore more applications in our [blog post on object tracking](https://www.ultralytics.com/blog/object-detection-and-tracking-with-ultralytics-yolov8). - -## 🤝 Contribute New Trackers - -Are you experienced in multi-object tracking and have implemented or adapted an algorithm with Ultralytics YOLO? We encourage you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your contributions can help expand the tracking solutions available within the Ultralytics [ecosystem](https://docs.ultralytics.com/). - -To contribute, please review our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) for instructions on submitting a [Pull Request (PR)](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) 🛠️. We look forward to your contributions! - -Let's work together to enhance the tracking capabilities of Ultralytics YOLO and provide more powerful tools for the [computer vision](https://www.ultralytics.com/glossary/computer-vision-cv) and [deep learning](https://www.ultralytics.com/glossary/deep-learning-dl) community 🙏! 
diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/__init__.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/__init__.py deleted file mode 100644 index 2919511..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from .bot_sort import BOTSORT -from .byte_tracker import BYTETracker -from .track import register_tracker - -__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/basetrack.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/basetrack.py deleted file mode 100644 index d254883..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/basetrack.py +++ /dev/null @@ -1,117 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -"""Module defines the base classes and structures for object tracking in YOLO.""" - -from collections import OrderedDict -from typing import Any - -import numpy as np - - -class TrackState: - """ - Enumeration class representing the possible states of an object being tracked. - - Attributes: - New (int): State when the object is newly detected. - Tracked (int): State when the object is successfully tracked in subsequent frames. - Lost (int): State when the object is no longer tracked. - Removed (int): State when the object is removed from tracking. - - Examples: - >>> state = TrackState.New - >>> if state == TrackState.New: - >>> print("Object is newly detected.") - """ - - New = 0 - Tracked = 1 - Lost = 2 - Removed = 3 - - -class BaseTrack: - """ - Base class for object tracking, providing foundational attributes and methods. - - Attributes: - _count (int): Class-level counter for unique track IDs. - track_id (int): Unique identifier for the track. - is_activated (bool): Flag indicating whether the track is currently active. 
- state (TrackState): Current state of the track. - history (OrderedDict): Ordered history of the track's states. - features (list): List of features extracted from the object for tracking. - curr_feature (Any): The current feature of the object being tracked. - score (float): The confidence score of the tracking. - start_frame (int): The frame number where tracking started. - frame_id (int): The most recent frame ID processed by the track. - time_since_update (int): Frames passed since the last update. - location (tuple): The location of the object in the context of multi-camera tracking. - - Methods: - end_frame: Returns the ID of the last frame where the object was tracked. - next_id: Increments and returns the next global track ID. - activate: Abstract method to activate the track. - predict: Abstract method to predict the next state of the track. - update: Abstract method to update the track with new data. - mark_lost: Marks the track as lost. - mark_removed: Marks the track as removed. - reset_id: Resets the global track ID counter. 
- - Examples: - Initialize a new track and mark it as lost: - >>> track = BaseTrack() - >>> track.mark_lost() - >>> print(track.state) # Output: 2 (TrackState.Lost) - """ - - _count = 0 - - def __init__(self): - """Initialize a new track with a unique ID and foundational tracking attributes.""" - self.track_id = 0 - self.is_activated = False - self.state = TrackState.New - self.history = OrderedDict() - self.features = [] - self.curr_feature = None - self.score = 0 - self.start_frame = 0 - self.frame_id = 0 - self.time_since_update = 0 - self.location = (np.inf, np.inf) - - @property - def end_frame(self) -> int: - """Return the ID of the most recent frame where the object was tracked.""" - return self.frame_id - - @staticmethod - def next_id() -> int: - """Increment and return the next unique global track ID for object tracking.""" - BaseTrack._count += 1 - return BaseTrack._count - - def activate(self, *args: Any) -> None: - """Activate the track with provided arguments, initializing necessary attributes for tracking.""" - raise NotImplementedError - - def predict(self) -> None: - """Predict the next state of the track based on the current state and tracking model.""" - raise NotImplementedError - - def update(self, *args: Any, **kwargs: Any) -> None: - """Update the track with new observations and data, modifying its state and attributes accordingly.""" - raise NotImplementedError - - def mark_lost(self) -> None: - """Mark the track as lost by updating its state to TrackState.Lost.""" - self.state = TrackState.Lost - - def mark_removed(self) -> None: - """Mark the track as removed by setting its state to TrackState.Removed.""" - self.state = TrackState.Removed - - @staticmethod - def reset_id() -> None: - """Reset the global track ID counter to its initial value.""" - BaseTrack._count = 0 diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/bot_sort.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/bot_sort.py deleted file mode 
100644 index 1bbce2b..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/bot_sort.py +++ /dev/null @@ -1,272 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from collections import deque -from typing import Any, List, Optional - -import numpy as np -import torch - -from ultralytics.utils.ops import xywh2xyxy -from ultralytics.utils.plotting import save_one_box - -from .basetrack import TrackState -from .byte_tracker import BYTETracker, STrack -from .utils import matching -from .utils.gmc import GMC -from .utils.kalman_filter import KalmanFilterXYWH - - -class BOTrack(STrack): - """ - An extended version of the STrack class for YOLO, adding object tracking features. - - This class extends the STrack class to include additional functionalities for object tracking, such as feature - smoothing, Kalman filter prediction, and reactivation of tracks. - - Attributes: - shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack. - smooth_feat (np.ndarray): Smoothed feature vector. - curr_feat (np.ndarray): Current feature vector. - features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`. - alpha (float): Smoothing factor for the exponential moving average of features. - mean (np.ndarray): The mean state of the Kalman filter. - covariance (np.ndarray): The covariance matrix of the Kalman filter. - - Methods: - update_features: Update features vector and smooth it using exponential moving average. - predict: Predict the mean and covariance using Kalman filter. - re_activate: Reactivate a track with updated features and optionally new ID. - update: Update the track with new detection and frame ID. - tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`. - multi_predict: Predict the mean and covariance of multiple object tracks using shared Kalman filter. 
- convert_coords: Convert tlwh bounding box coordinates to xywh format. - tlwh_to_xywh: Convert bounding box to xywh format `(center x, center y, width, height)`. - - Examples: - Create a BOTrack instance and update its features - >>> bo_track = BOTrack(tlwh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128)) - >>> bo_track.predict() - >>> new_track = BOTrack(tlwh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128)) - >>> bo_track.update(new_track, frame_id=2) - """ - - shared_kalman = KalmanFilterXYWH() - - def __init__( - self, xywh: np.ndarray, score: float, cls: int, feat: Optional[np.ndarray] = None, feat_history: int = 50 - ): - """ - Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features. - - Args: - xywh (np.ndarray): Bounding box coordinates in xywh format (center x, center y, width, height). - score (float): Confidence score of the detection. - cls (int): Class ID of the detected object. - feat (np.ndarray, optional): Feature vector associated with the detection. - feat_history (int): Maximum length of the feature history deque. 
- - Examples: - Initialize a BOTrack object with bounding box, score, class ID, and feature vector - >>> xywh = np.array([100, 150, 60, 50]) - >>> score = 0.9 - >>> cls = 1 - >>> feat = np.random.rand(128) - >>> bo_track = BOTrack(xywh, score, cls, feat) - """ - super().__init__(xywh, score, cls) - - self.smooth_feat = None - self.curr_feat = None - if feat is not None: - self.update_features(feat) - self.features = deque([], maxlen=feat_history) - self.alpha = 0.9 - - def update_features(self, feat: np.ndarray) -> None: - """Update the feature vector and apply exponential moving average smoothing.""" - feat /= np.linalg.norm(feat) - self.curr_feat = feat - if self.smooth_feat is None: - self.smooth_feat = feat - else: - self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat - self.features.append(feat) - self.smooth_feat /= np.linalg.norm(self.smooth_feat) - - def predict(self) -> None: - """Predict the object's future state using the Kalman filter to update its mean and covariance.""" - mean_state = self.mean.copy() - if self.state != TrackState.Tracked: - mean_state[6] = 0 - mean_state[7] = 0 - - self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) - - def re_activate(self, new_track: "BOTrack", frame_id: int, new_id: bool = False) -> None: - """Reactivate a track with updated features and optionally assign a new ID.""" - if new_track.curr_feat is not None: - self.update_features(new_track.curr_feat) - super().re_activate(new_track, frame_id, new_id) - - def update(self, new_track: "BOTrack", frame_id: int) -> None: - """Update the track with new detection information and the current frame ID.""" - if new_track.curr_feat is not None: - self.update_features(new_track.curr_feat) - super().update(new_track, frame_id) - - @property - def tlwh(self) -> np.ndarray: - """Return the current bounding box position in `(top left x, top left y, width, height)` format.""" - if self.mean is None: - return self._tlwh.copy() - 
ret = self.mean[:4].copy() - ret[:2] -= ret[2:] / 2 - return ret - - @staticmethod - def multi_predict(stracks: List["BOTrack"]) -> None: - """Predict the mean and covariance for multiple object tracks using a shared Kalman filter.""" - if len(stracks) <= 0: - return - multi_mean = np.asarray([st.mean.copy() for st in stracks]) - multi_covariance = np.asarray([st.covariance for st in stracks]) - for i, st in enumerate(stracks): - if st.state != TrackState.Tracked: - multi_mean[i][6] = 0 - multi_mean[i][7] = 0 - multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance) - for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): - stracks[i].mean = mean - stracks[i].covariance = cov - - def convert_coords(self, tlwh: np.ndarray) -> np.ndarray: - """Convert tlwh bounding box coordinates to xywh format.""" - return self.tlwh_to_xywh(tlwh) - - @staticmethod - def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray: - """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format.""" - ret = np.asarray(tlwh).copy() - ret[:2] += ret[2:] / 2 - return ret - - -class BOTSORT(BYTETracker): - """ - An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm. - - Attributes: - proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections. - appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections. - encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled. - gmc (GMC): An instance of the GMC algorithm for data association. - args (Any): Parsed command-line arguments containing tracking parameters. - - Methods: - get_kalmanfilter: Return an instance of KalmanFilterXYWH for object tracking. - init_track: Initialize track with detections, scores, and classes. 
- get_dists: Get distances between tracks and detections using IoU and (optionally) ReID. - multi_predict: Predict and track multiple objects with a YOLO model. - reset: Reset the BOTSORT tracker to its initial state. - - Examples: - Initialize BOTSORT and process detections - >>> bot_sort = BOTSORT(args, frame_rate=30) - >>> bot_sort.init_track(dets, scores, cls, img) - >>> bot_sort.multi_predict(tracks) - - Note: - The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args. - """ - - def __init__(self, args: Any, frame_rate: int = 30): - """ - Initialize BOTSORT object with ReID module and GMC algorithm. - - Args: - args (Any): Parsed command-line arguments containing tracking parameters. - frame_rate (int): Frame rate of the video being processed. - - Examples: - Initialize BOTSORT with command-line arguments and a specified frame rate: - >>> args = parse_args() - >>> bot_sort = BOTSORT(args, frame_rate=30) - """ - super().__init__(args, frame_rate) - self.gmc = GMC(method=args.gmc_method) - - # ReID module - self.proximity_thresh = args.proximity_thresh - self.appearance_thresh = args.appearance_thresh - self.encoder = ( - (lambda feats, s: [f.cpu().numpy() for f in feats]) # native features do not require any model - if args.with_reid and self.args.model == "auto" - else ReID(args.model) - if args.with_reid - else None - ) - - def get_kalmanfilter(self) -> KalmanFilterXYWH: - """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process.""" - return KalmanFilterXYWH() - - def init_track(self, results, img: Optional[np.ndarray] = None) -> List[BOTrack]: - """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features.""" - if len(results) == 0: - return [] - bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh - bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1) - if 
self.args.with_reid and self.encoder is not None: - features_keep = self.encoder(img, bboxes) - return [BOTrack(xywh, s, c, f) for (xywh, s, c, f) in zip(bboxes, results.conf, results.cls, features_keep)] - else: - return [BOTrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)] - - def get_dists(self, tracks: List[BOTrack], detections: List[BOTrack]) -> np.ndarray: - """Calculate distances between tracks and detections using IoU and optionally ReID embeddings.""" - dists = matching.iou_distance(tracks, detections) - dists_mask = dists > (1 - self.proximity_thresh) - - if self.args.fuse_score: - dists = matching.fuse_score(dists, detections) - - if self.args.with_reid and self.encoder is not None: - emb_dists = matching.embedding_distance(tracks, detections) / 2.0 - emb_dists[emb_dists > (1 - self.appearance_thresh)] = 1.0 - emb_dists[dists_mask] = 1.0 - dists = np.minimum(dists, emb_dists) - return dists - - def multi_predict(self, tracks: List[BOTrack]) -> None: - """Predict the mean and covariance of multiple object tracks using a shared Kalman filter.""" - BOTrack.multi_predict(tracks) - - def reset(self) -> None: - """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states.""" - super().reset() - self.gmc.reset_params() - - -class ReID: - """YOLO model as encoder for re-identification.""" - - def __init__(self, model: str): - """ - Initialize encoder for re-identification. - - Args: - model (str): Path to the YOLO model for re-identification. 
- """ - from ultralytics import YOLO - - self.model = YOLO(model) - self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False, save=False) # init - - def __call__(self, img: np.ndarray, dets: np.ndarray) -> List[np.ndarray]: - """Extract embeddings for detected objects.""" - feats = self.model.predictor( - [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))] - ) - if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]: - feats = feats[0] # batched prediction with non-PyTorch backend - return [f.cpu().numpy() for f in feats] diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/byte_tracker.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/byte_tracker.py deleted file mode 100644 index ea7d178..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/byte_tracker.py +++ /dev/null @@ -1,483 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from typing import Any, List, Optional, Tuple - -import numpy as np - -from ..utils import LOGGER -from ..utils.ops import xywh2ltwh -from .basetrack import BaseTrack, TrackState -from .utils import matching -from .utils.kalman_filter import KalmanFilterXYAH - - -class STrack(BaseTrack): - """ - Single object tracking representation that uses Kalman filtering for state estimation. - - This class is responsible for storing all the information regarding individual tracklets and performs state updates - and predictions based on Kalman filter. - - Attributes: - shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction. - _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box. - kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track. - mean (np.ndarray): Mean state estimate vector. - covariance (np.ndarray): Covariance of state estimate. 
- is_activated (bool): Boolean flag indicating if the track has been activated. - score (float): Confidence score of the track. - tracklet_len (int): Length of the tracklet. - cls (Any): Class label for the object. - idx (int): Index or identifier for the object. - frame_id (int): Current frame ID. - start_frame (int): Frame where the object was first detected. - angle (float | None): Optional angle information for oriented bounding boxes. - - Methods: - predict: Predict the next state of the object using Kalman filter. - multi_predict: Predict the next states for multiple tracks. - multi_gmc: Update multiple track states using a homography matrix. - activate: Activate a new tracklet. - re_activate: Reactivate a previously lost tracklet. - update: Update the state of a matched track. - convert_coords: Convert bounding box to x-y-aspect-height format. - tlwh_to_xyah: Convert tlwh bounding box to xyah format. - - Examples: - Initialize and activate a new track - >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person") - >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1) - """ - - shared_kalman = KalmanFilterXYAH() - - def __init__(self, xywh: List[float], score: float, cls: Any): - """ - Initialize a new STrack instance. - - Args: - xywh (List[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where - (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id. - score (float): Confidence score of the detection. - cls (Any): Class label for the detected object. 
- - Examples: - >>> xywh = [100.0, 150.0, 50.0, 75.0, 1] - >>> score = 0.9 - >>> cls = "person" - >>> track = STrack(xywh, score, cls) - """ - super().__init__() - # xywh+idx or xywha+idx - assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}" - self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32) - self.kalman_filter = None - self.mean, self.covariance = None, None - self.is_activated = False - - self.score = score - self.tracklet_len = 0 - self.cls = cls - self.idx = xywh[-1] - self.angle = xywh[4] if len(xywh) == 6 else None - - def predict(self): - """Predict the next state (mean and covariance) of the object using the Kalman filter.""" - mean_state = self.mean.copy() - if self.state != TrackState.Tracked: - mean_state[7] = 0 - self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance) - - @staticmethod - def multi_predict(stracks: List["STrack"]): - """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances.""" - if len(stracks) <= 0: - return - multi_mean = np.asarray([st.mean.copy() for st in stracks]) - multi_covariance = np.asarray([st.covariance for st in stracks]) - for i, st in enumerate(stracks): - if st.state != TrackState.Tracked: - multi_mean[i][7] = 0 - multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance) - for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)): - stracks[i].mean = mean - stracks[i].covariance = cov - - @staticmethod - def multi_gmc(stracks: List["STrack"], H: np.ndarray = np.eye(2, 3)): - """Update state tracks positions and covariances using a homography matrix for multiple tracks.""" - if len(stracks) > 0: - multi_mean = np.asarray([st.mean.copy() for st in stracks]) - multi_covariance = np.asarray([st.covariance for st in stracks]) - - R = H[:2, :2] - R8x8 = np.kron(np.eye(4, dtype=float), R) - t = H[:2, 2] - - for i, (mean, cov) in enumerate(zip(multi_mean, 
multi_covariance)): - mean = R8x8.dot(mean) - mean[:2] += t - cov = R8x8.dot(cov).dot(R8x8.transpose()) - - stracks[i].mean = mean - stracks[i].covariance = cov - - def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int): - """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance.""" - self.kalman_filter = kalman_filter - self.track_id = self.next_id() - self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh)) - - self.tracklet_len = 0 - self.state = TrackState.Tracked - if frame_id == 1: - self.is_activated = True - self.frame_id = frame_id - self.start_frame = frame_id - - def re_activate(self, new_track: "STrack", frame_id: int, new_id: bool = False): - """Reactivate a previously lost track using new detection data and update its state and attributes.""" - self.mean, self.covariance = self.kalman_filter.update( - self.mean, self.covariance, self.convert_coords(new_track.tlwh) - ) - self.tracklet_len = 0 - self.state = TrackState.Tracked - self.is_activated = True - self.frame_id = frame_id - if new_id: - self.track_id = self.next_id() - self.score = new_track.score - self.cls = new_track.cls - self.angle = new_track.angle - self.idx = new_track.idx - - def update(self, new_track: "STrack", frame_id: int): - """ - Update the state of a matched track. - - Args: - new_track (STrack): The new track containing updated information. - frame_id (int): The ID of the current frame. 
- - Examples: - Update the state of a track with new detection information - >>> track = STrack([100, 200, 50, 80, 0.9, 1]) - >>> new_track = STrack([105, 205, 55, 85, 0.95, 1]) - >>> track.update(new_track, 2) - """ - self.frame_id = frame_id - self.tracklet_len += 1 - - new_tlwh = new_track.tlwh - self.mean, self.covariance = self.kalman_filter.update( - self.mean, self.covariance, self.convert_coords(new_tlwh) - ) - self.state = TrackState.Tracked - self.is_activated = True - - self.score = new_track.score - self.cls = new_track.cls - self.angle = new_track.angle - self.idx = new_track.idx - - def convert_coords(self, tlwh: np.ndarray) -> np.ndarray: - """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent.""" - return self.tlwh_to_xyah(tlwh) - - @property - def tlwh(self) -> np.ndarray: - """Get the bounding box in top-left-width-height format from the current state estimate.""" - if self.mean is None: - return self._tlwh.copy() - ret = self.mean[:4].copy() - ret[2] *= ret[3] - ret[:2] -= ret[2:] / 2 - return ret - - @property - def xyxy(self) -> np.ndarray: - """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format.""" - ret = self.tlwh.copy() - ret[2:] += ret[:2] - return ret - - @staticmethod - def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray: - """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format.""" - ret = np.asarray(tlwh).copy() - ret[:2] += ret[2:] / 2 - ret[2] /= ret[3] - return ret - - @property - def xywh(self) -> np.ndarray: - """Get the current position of the bounding box in (center x, center y, width, height) format.""" - ret = np.asarray(self.tlwh).copy() - ret[:2] += ret[2:] / 2 - return ret - - @property - def xywha(self) -> np.ndarray: - """Get position in (center x, center y, width, height, angle) format, warning if angle is missing.""" - if self.angle is None: - LOGGER.warning("`angle` attr not found, returning 
`xywh` instead.") - return self.xywh - return np.concatenate([self.xywh, self.angle[None]]) - - @property - def result(self) -> List[float]: - """Get the current tracking results in the appropriate bounding box format.""" - coords = self.xyxy if self.angle is None else self.xywha - return coords.tolist() + [self.track_id, self.score, self.cls, self.idx] - - def __repr__(self) -> str: - """Return a string representation of the STrack object including start frame, end frame, and track ID.""" - return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})" - - -class BYTETracker: - """ - BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking. - - This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a - video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for - predicting the new object locations, and performs data association. - - Attributes: - tracked_stracks (List[STrack]): List of successfully activated tracks. - lost_stracks (List[STrack]): List of lost tracks. - removed_stracks (List[STrack]): List of removed tracks. - frame_id (int): The current frame ID. - args (Namespace): Command-line arguments. - max_time_lost (int): The maximum frames for a track to be considered as 'lost'. - kalman_filter (KalmanFilterXYAH): Kalman Filter object. - - Methods: - update: Update object tracker with new detections. - get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes. - init_track: Initialize object tracking with detections. - get_dists: Calculate the distance between tracks and detections. - multi_predict: Predict the location of tracks. - reset_id: Reset the ID counter of STrack. - reset: Reset the tracker by clearing all tracks. - joint_stracks: Combine two lists of stracks. - sub_stracks: Filter out the stracks present in the second list from the first list. 
- remove_duplicate_stracks: Remove duplicate stracks based on IoU. - - Examples: - Initialize BYTETracker and update with detection results - >>> tracker = BYTETracker(args, frame_rate=30) - >>> results = yolo_model.detect(image) - >>> tracked_objects = tracker.update(results) - """ - - def __init__(self, args, frame_rate: int = 30): - """ - Initialize a BYTETracker instance for object tracking. - - Args: - args (Namespace): Command-line arguments containing tracking parameters. - frame_rate (int): Frame rate of the video sequence. - - Examples: - Initialize BYTETracker with command-line arguments and a frame rate of 30 - >>> args = Namespace(track_buffer=30) - >>> tracker = BYTETracker(args, frame_rate=30) - """ - self.tracked_stracks = [] # type: List[STrack] - self.lost_stracks = [] # type: List[STrack] - self.removed_stracks = [] # type: List[STrack] - - self.frame_id = 0 - self.args = args - self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer) - self.kalman_filter = self.get_kalmanfilter() - self.reset_id() - - def update(self, results, img: Optional[np.ndarray] = None, feats: Optional[np.ndarray] = None) -> np.ndarray: - """Update the tracker with new detections and return the current list of tracked objects.""" - self.frame_id += 1 - activated_stracks = [] - refind_stracks = [] - lost_stracks = [] - removed_stracks = [] - - scores = results.conf - remain_inds = scores >= self.args.track_high_thresh - inds_low = scores > self.args.track_low_thresh - inds_high = scores < self.args.track_high_thresh - - inds_second = inds_low & inds_high - results_second = results[inds_second] - results = results[remain_inds] - feats_keep = feats_second = img - if feats is not None and len(feats): - feats_keep = feats[remain_inds] - feats_second = feats[inds_second] - - detections = self.init_track(results, feats_keep) - # Add newly detected tracklets to tracked_stracks - unconfirmed = [] - tracked_stracks = [] # type: List[STrack] - for track in 
self.tracked_stracks: - if not track.is_activated: - unconfirmed.append(track) - else: - tracked_stracks.append(track) - # Step 2: First association, with high score detection boxes - strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks) - # Predict the current location with KF - self.multi_predict(strack_pool) - if hasattr(self, "gmc") and img is not None: - # use try-except here to bypass errors from gmc module - try: - warp = self.gmc.apply(img, results.xyxy) - except Exception: - warp = np.eye(2, 3) - STrack.multi_gmc(strack_pool, warp) - STrack.multi_gmc(unconfirmed, warp) - - dists = self.get_dists(strack_pool, detections) - matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh) - - for itracked, idet in matches: - track = strack_pool[itracked] - det = detections[idet] - if track.state == TrackState.Tracked: - track.update(det, self.frame_id) - activated_stracks.append(track) - else: - track.re_activate(det, self.frame_id, new_id=False) - refind_stracks.append(track) - # Step 3: Second association, with low score detection boxes association the untrack to the low score detections - detections_second = self.init_track(results_second, feats_second) - r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked] - # TODO - dists = matching.iou_distance(r_tracked_stracks, detections_second) - matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5) - for itracked, idet in matches: - track = r_tracked_stracks[itracked] - det = detections_second[idet] - if track.state == TrackState.Tracked: - track.update(det, self.frame_id) - activated_stracks.append(track) - else: - track.re_activate(det, self.frame_id, new_id=False) - refind_stracks.append(track) - - for it in u_track: - track = r_tracked_stracks[it] - if track.state != TrackState.Lost: - track.mark_lost() - lost_stracks.append(track) - # Deal with unconfirmed tracks, usually tracks with 
only one beginning frame - detections = [detections[i] for i in u_detection] - dists = self.get_dists(unconfirmed, detections) - matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7) - for itracked, idet in matches: - unconfirmed[itracked].update(detections[idet], self.frame_id) - activated_stracks.append(unconfirmed[itracked]) - for it in u_unconfirmed: - track = unconfirmed[it] - track.mark_removed() - removed_stracks.append(track) - # Step 4: Init new stracks - for inew in u_detection: - track = detections[inew] - if track.score < self.args.new_track_thresh: - continue - track.activate(self.kalman_filter, self.frame_id) - activated_stracks.append(track) - # Step 5: Update state - for track in self.lost_stracks: - if self.frame_id - track.end_frame > self.max_time_lost: - track.mark_removed() - removed_stracks.append(track) - - self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked] - self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks) - self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks) - self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks) - self.lost_stracks.extend(lost_stracks) - self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks) - self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks) - self.removed_stracks.extend(removed_stracks) - if len(self.removed_stracks) > 1000: - self.removed_stracks = self.removed_stracks[-999:] # clip remove stracks to 1000 maximum - - return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32) - - def get_kalmanfilter(self) -> KalmanFilterXYAH: - """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH.""" - return KalmanFilterXYAH() - - def init_track(self, results, img: Optional[np.ndarray] = None) -> List[STrack]: - """Initialize object 
tracking with given detections, scores, and class labels using the STrack algorithm.""" - if len(results) == 0: - return [] - bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh - bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1) - return [STrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)] - - def get_dists(self, tracks: List[STrack], detections: List[STrack]) -> np.ndarray: - """Calculate the distance between tracks and detections using IoU and optionally fuse scores.""" - dists = matching.iou_distance(tracks, detections) - if self.args.fuse_score: - dists = matching.fuse_score(dists, detections) - return dists - - def multi_predict(self, tracks: List[STrack]): - """Predict the next states for multiple tracks using Kalman filter.""" - STrack.multi_predict(tracks) - - @staticmethod - def reset_id(): - """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions.""" - STrack.reset_id() - - def reset(self): - """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter.""" - self.tracked_stracks = [] # type: List[STrack] - self.lost_stracks = [] # type: List[STrack] - self.removed_stracks = [] # type: List[STrack] - self.frame_id = 0 - self.kalman_filter = self.get_kalmanfilter() - self.reset_id() - - @staticmethod - def joint_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]: - """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs.""" - exists = {} - res = [] - for t in tlista: - exists[t.track_id] = 1 - res.append(t) - for t in tlistb: - tid = t.track_id - if not exists.get(tid, 0): - exists[tid] = 1 - res.append(t) - return res - - @staticmethod - def sub_stracks(tlista: List[STrack], tlistb: List[STrack]) -> List[STrack]: - """Filter out the stracks present in the second list from the first list.""" - track_ids_b = {t.track_id for t in 
tlistb} - return [t for t in tlista if t.track_id not in track_ids_b] - - @staticmethod - def remove_duplicate_stracks(stracksa: List[STrack], stracksb: List[STrack]) -> Tuple[List[STrack], List[STrack]]: - """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance.""" - pdist = matching.iou_distance(stracksa, stracksb) - pairs = np.where(pdist < 0.15) - dupa, dupb = [], [] - for p, q in zip(*pairs): - timep = stracksa[p].frame_id - stracksa[p].start_frame - timeq = stracksb[q].frame_id - stracksb[q].start_frame - if timep > timeq: - dupb.append(q) - else: - dupa.append(p) - resa = [t for i, t in enumerate(stracksa) if i not in dupa] - resb = [t for i, t in enumerate(stracksb) if i not in dupb] - return resa, resb diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/track.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/track.py deleted file mode 100644 index 8720f73..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/track.py +++ /dev/null @@ -1,119 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from functools import partial -from pathlib import Path - -import torch - -from ultralytics.utils import YAML, IterableSimpleNamespace -from ultralytics.utils.checks import check_yaml - -from .bot_sort import BOTSORT -from .byte_tracker import BYTETracker - -# A mapping of tracker types to corresponding tracker classes -TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT} - - -def on_predict_start(predictor: object, persist: bool = False) -> None: - """ - Initialize trackers for object tracking during prediction. - - Args: - predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for. - persist (bool, optional): Whether to persist the trackers if they already exist. 
- - Examples: - Initialize trackers for a predictor object - >>> predictor = SomePredictorClass() - >>> on_predict_start(predictor, persist=True) - """ - if predictor.args.task == "classify": - raise ValueError("❌ Classification doesn't support 'mode=track'") - - if hasattr(predictor, "trackers") and persist: - return - - tracker = check_yaml(predictor.args.tracker) - cfg = IterableSimpleNamespace(**YAML.load(tracker)) - - if cfg.tracker_type not in {"bytetrack", "botsort"}: - raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'") - - predictor._feats = None # reset in case used earlier - if hasattr(predictor, "_hook"): - predictor._hook.remove() - if cfg.tracker_type == "botsort" and cfg.with_reid and cfg.model == "auto": - from ultralytics.nn.modules.head import Detect - - if not ( - isinstance(predictor.model.model, torch.nn.Module) - and isinstance(predictor.model.model.model[-1], Detect) - and not predictor.model.model.model[-1].end2end - ): - cfg.model = "yolo11n-cls.pt" - else: - # Register hook to extract input of Detect layer - def pre_hook(module, input): - predictor._feats = list(input[0]) # unroll to new list to avoid mutation in forward - - predictor._hook = predictor.model.model.model[-1].register_forward_pre_hook(pre_hook) - - trackers = [] - for _ in range(predictor.dataset.bs): - tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30) - trackers.append(tracker) - if predictor.dataset.mode != "stream": # only need one tracker for other modes - break - predictor.trackers = trackers - predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video - - -def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None: - """ - Postprocess detected boxes and update with object tracking. - - Args: - predictor (object): The predictor object containing the predictions. 
- persist (bool, optional): Whether to persist the trackers if they already exist. - - Examples: - Postprocess predictions and update with tracking - >>> predictor = YourPredictorClass() - >>> on_predict_postprocess_end(predictor, persist=True) - """ - is_obb = predictor.args.task == "obb" - is_stream = predictor.dataset.mode == "stream" - for i, result in enumerate(predictor.results): - tracker = predictor.trackers[i if is_stream else 0] - vid_path = predictor.save_dir / Path(result.path).name - if not persist and predictor.vid_path[i if is_stream else 0] != vid_path: - tracker.reset() - predictor.vid_path[i if is_stream else 0] = vid_path - - det = (result.obb if is_obb else result.boxes).cpu().numpy() - tracks = tracker.update(det, result.orig_img, getattr(result, "feats", None)) - if len(tracks) == 0: - continue - idx = tracks[:, -1].astype(int) - predictor.results[i] = result[idx] - - update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])} - predictor.results[i].update(**update_args) - - -def register_tracker(model: object, persist: bool) -> None: - """ - Register tracking callbacks to the model for object tracking during prediction. - - Args: - model (object): The model object to register tracking callbacks for. - persist (bool): Whether to persist the trackers if they already exist. 
- - Examples: - Register tracking callbacks to a YOLO model - >>> model = YOLOModel() - >>> register_tracker(model, persist=True) - """ - model.add_callback("on_predict_start", partial(on_predict_start, persist=persist)) - model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist)) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/__init__.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/__init__.py deleted file mode 100644 index 77a19dc..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/gmc.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/gmc.py deleted file mode 100644 index fc0dd5d..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/gmc.py +++ /dev/null @@ -1,349 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import copy -from typing import List, Optional - -import cv2 -import numpy as np - -from ultralytics.utils import LOGGER - - -class GMC: - """ - Generalized Motion Compensation (GMC) class for tracking and object detection in video frames. - - This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB, - SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency. - - Attributes: - method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. - downscale (int): Factor by which to downscale the frames for processing. - prevFrame (np.ndarray): Previous frame for tracking. - prevKeyPoints (List): Keypoints from the previous frame. - prevDescriptors (np.ndarray): Descriptors from the previous frame. 
- initializedFirstFrame (bool): Flag indicating if the first frame has been processed. - - Methods: - apply: Apply the chosen method to a raw frame and optionally use provided detections. - apply_ecc: Apply the ECC algorithm to a raw frame. - apply_features: Apply feature-based methods like ORB or SIFT to a raw frame. - apply_sparseoptflow: Apply the Sparse Optical Flow method to a raw frame. - reset_params: Reset the internal parameters of the GMC object. - - Examples: - Create a GMC object and apply it to a frame - >>> gmc = GMC(method="sparseOptFlow", downscale=2) - >>> frame = np.array([[1, 2, 3], [4, 5, 6]]) - >>> processed_frame = gmc.apply(frame) - >>> print(processed_frame) - array([[1, 2, 3], - [4, 5, 6]]) - """ - - def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None: - """ - Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor. - - Args: - method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'. - downscale (int): Downscale factor for processing frames. 
- - Examples: - Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2 - >>> gmc = GMC(method="sparseOptFlow", downscale=2) - """ - super().__init__() - - self.method = method - self.downscale = max(1, downscale) - - if self.method == "orb": - self.detector = cv2.FastFeatureDetector_create(20) - self.extractor = cv2.ORB_create() - self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING) - - elif self.method == "sift": - self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) - self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20) - self.matcher = cv2.BFMatcher(cv2.NORM_L2) - - elif self.method == "ecc": - number_of_iterations = 5000 - termination_eps = 1e-6 - self.warp_mode = cv2.MOTION_EUCLIDEAN - self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps) - - elif self.method == "sparseOptFlow": - self.feature_params = dict( - maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04 - ) - - elif self.method in {"none", "None", None}: - self.method = None - else: - raise ValueError(f"Unknown GMC method: {method}") - - self.prevFrame = None - self.prevKeyPoints = None - self.prevDescriptors = None - self.initializedFirstFrame = False - - def apply(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray: - """ - Apply object detection on a raw frame using the specified method. - - Args: - raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). - detections (List, optional): List of detections to be used in the processing. - - Returns: - (np.ndarray): Transformation matrix with shape (2, 3). 
- - Examples: - >>> gmc = GMC(method="sparseOptFlow") - >>> raw_frame = np.random.rand(480, 640, 3) - >>> transformation_matrix = gmc.apply(raw_frame) - >>> print(transformation_matrix.shape) - (2, 3) - """ - if self.method in {"orb", "sift"}: - return self.apply_features(raw_frame, detections) - elif self.method == "ecc": - return self.apply_ecc(raw_frame) - elif self.method == "sparseOptFlow": - return self.apply_sparseoptflow(raw_frame) - else: - return np.eye(2, 3) - - def apply_ecc(self, raw_frame: np.ndarray) -> np.ndarray: - """ - Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation. - - Args: - raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). - - Returns: - (np.ndarray): Transformation matrix with shape (2, 3). - - Examples: - >>> gmc = GMC(method="ecc") - >>> processed_frame = gmc.apply_ecc(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) - >>> print(processed_frame) - [[1. 0. 0.] - [0. 1. 0.]] - """ - height, width, c = raw_frame.shape - frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame - H = np.eye(2, 3, dtype=np.float32) - - # Downscale image for computational efficiency - if self.downscale > 1.0: - frame = cv2.GaussianBlur(frame, (3, 3), 1.5) - frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) - - # Handle first frame initialization - if not self.initializedFirstFrame: - self.prevFrame = frame.copy() - self.initializedFirstFrame = True - return H - - # Run the ECC algorithm to find transformation matrix - try: - (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1) - except Exception as e: - LOGGER.warning(f"find transform failed. Set warp as identity {e}") - - return H - - def apply_features(self, raw_frame: np.ndarray, detections: Optional[List] = None) -> np.ndarray: - """ - Apply feature-based methods like ORB or SIFT to a raw frame. 
- - Args: - raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). - detections (List, optional): List of detections to be used in the processing. - - Returns: - (np.ndarray): Transformation matrix with shape (2, 3). - - Examples: - >>> gmc = GMC(method="orb") - >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) - >>> transformation_matrix = gmc.apply_features(raw_frame) - >>> print(transformation_matrix.shape) - (2, 3) - """ - height, width, c = raw_frame.shape - frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame - H = np.eye(2, 3) - - # Downscale image for computational efficiency - if self.downscale > 1.0: - frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) - width = width // self.downscale - height = height // self.downscale - - # Create mask for keypoint detection, excluding border regions - mask = np.zeros_like(frame) - mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255 - - # Exclude detection regions from mask to avoid tracking detected objects - if detections is not None: - for det in detections: - tlbr = (det[:4] / self.downscale).astype(np.int_) - mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0 - - # Find keypoints and compute descriptors - keypoints = self.detector.detect(frame, mask) - keypoints, descriptors = self.extractor.compute(frame, keypoints) - - # Handle first frame initialization - if not self.initializedFirstFrame: - self.prevFrame = frame.copy() - self.prevKeyPoints = copy.copy(keypoints) - self.prevDescriptors = copy.copy(descriptors) - self.initializedFirstFrame = True - return H - - # Match descriptors between previous and current frame - knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2) - - # Filter matches based on spatial distance constraints - matches = [] - spatialDistances = [] - maxSpatialDistance = 0.25 * np.array([width, height]) - - # Handle empty matches case - if 
len(knnMatches) == 0: - self.prevFrame = frame.copy() - self.prevKeyPoints = copy.copy(keypoints) - self.prevDescriptors = copy.copy(descriptors) - return H - - # Apply Lowe's ratio test and spatial distance filtering - for m, n in knnMatches: - if m.distance < 0.9 * n.distance: - prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt - currKeyPointLocation = keypoints[m.trainIdx].pt - - spatialDistance = ( - prevKeyPointLocation[0] - currKeyPointLocation[0], - prevKeyPointLocation[1] - currKeyPointLocation[1], - ) - - if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and ( - np.abs(spatialDistance[1]) < maxSpatialDistance[1] - ): - spatialDistances.append(spatialDistance) - matches.append(m) - - # Filter outliers using statistical analysis - meanSpatialDistances = np.mean(spatialDistances, 0) - stdSpatialDistances = np.std(spatialDistances, 0) - inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances - - # Extract good matches and corresponding points - goodMatches = [] - prevPoints = [] - currPoints = [] - for i in range(len(matches)): - if inliers[i, 0] and inliers[i, 1]: - goodMatches.append(matches[i]) - prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt) - currPoints.append(keypoints[matches[i].trainIdx].pt) - - prevPoints = np.array(prevPoints) - currPoints = np.array(currPoints) - - # Estimate transformation matrix using RANSAC - if prevPoints.shape[0] > 4: - H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) - - # Scale translation components back to original resolution - if self.downscale > 1.0: - H[0, 2] *= self.downscale - H[1, 2] *= self.downscale - else: - LOGGER.warning("not enough matching points") - - # Store current frame data for next iteration - self.prevFrame = frame.copy() - self.prevKeyPoints = copy.copy(keypoints) - self.prevDescriptors = copy.copy(descriptors) - - return H - - def apply_sparseoptflow(self, raw_frame: np.ndarray) -> np.ndarray: - """ - Apply Sparse Optical 
Flow method to a raw frame. - - Args: - raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C). - - Returns: - (np.ndarray): Transformation matrix with shape (2, 3). - - Examples: - >>> gmc = GMC() - >>> result = gmc.apply_sparseoptflow(np.array([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])) - >>> print(result) - [[1. 0. 0.] - [0. 1. 0.]] - """ - height, width, c = raw_frame.shape - frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame - H = np.eye(2, 3) - - # Downscale image for computational efficiency - if self.downscale > 1.0: - frame = cv2.resize(frame, (width // self.downscale, height // self.downscale)) - - # Find good features to track - keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params) - - # Handle first frame initialization - if not self.initializedFirstFrame or self.prevKeyPoints is None: - self.prevFrame = frame.copy() - self.prevKeyPoints = copy.copy(keypoints) - self.initializedFirstFrame = True - return H - - # Calculate optical flow using Lucas-Kanade method - matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None) - - # Extract successfully tracked points - prevPoints = [] - currPoints = [] - - for i in range(len(status)): - if status[i]: - prevPoints.append(self.prevKeyPoints[i]) - currPoints.append(matchedKeypoints[i]) - - prevPoints = np.array(prevPoints) - currPoints = np.array(currPoints) - - # Estimate transformation matrix using RANSAC - if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]): - H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC) - - # Scale translation components back to original resolution - if self.downscale > 1.0: - H[0, 2] *= self.downscale - H[1, 2] *= self.downscale - else: - LOGGER.warning("not enough matching points") - - # Store current frame data for next iteration - self.prevFrame = frame.copy() - self.prevKeyPoints = copy.copy(keypoints) - - return 
H - - def reset_params(self) -> None: - """Reset the internal parameters including previous frame, keypoints, and descriptors.""" - self.prevFrame = None - self.prevKeyPoints = None - self.prevDescriptors = None - self.initializedFirstFrame = False diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/kalman_filter.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/kalman_filter.py deleted file mode 100644 index 82fd515..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/kalman_filter.py +++ /dev/null @@ -1,493 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import numpy as np -import scipy.linalg - - -class KalmanFilterXYAH: - """ - A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter. - - Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space - (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their - respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is - taken as a direct observation of the state space (linear observation model). - - Attributes: - _motion_mat (np.ndarray): The motion matrix for the Kalman filter. - _update_mat (np.ndarray): The update matrix for the Kalman filter. - _std_weight_position (float): Standard deviation weight for position. - _std_weight_velocity (float): Standard deviation weight for velocity. - - Methods: - initiate: Create a track from an unassociated measurement. - predict: Run the Kalman filter prediction step. - project: Project the state distribution to measurement space. - multi_predict: Run the Kalman filter prediction step (vectorized version). - update: Run the Kalman filter correction step. - gating_distance: Compute the gating distance between state distribution and measurements. 
- - Examples: - Initialize the Kalman filter and create a track from a measurement - >>> kf = KalmanFilterXYAH() - >>> measurement = np.array([100, 200, 1.5, 50]) - >>> mean, covariance = kf.initiate(measurement) - >>> print(mean) - >>> print(covariance) - """ - - def __init__(self): - """ - Initialize Kalman filter model matrices with motion and observation uncertainty weights. - - The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y) - represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective - velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear - observation model for bounding box location. - - Examples: - Initialize a Kalman filter for tracking: - >>> kf = KalmanFilterXYAH() - """ - ndim, dt = 4, 1.0 - - # Create Kalman filter model matrices - self._motion_mat = np.eye(2 * ndim, 2 * ndim) - for i in range(ndim): - self._motion_mat[i, ndim + i] = dt - self._update_mat = np.eye(ndim, 2 * ndim) - - # Motion and observation uncertainty are chosen relative to the current state estimate - self._std_weight_position = 1.0 / 20 - self._std_weight_velocity = 1.0 / 160 - - def initiate(self, measurement: np.ndarray): - """ - Create a track from an unassociated measurement. - - Args: - measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a, - and height h. - - Returns: - mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean. - covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track. 
- - Examples: - >>> kf = KalmanFilterXYAH() - >>> measurement = np.array([100, 50, 1.5, 200]) - >>> mean, covariance = kf.initiate(measurement) - """ - mean_pos = measurement - mean_vel = np.zeros_like(mean_pos) - mean = np.r_[mean_pos, mean_vel] - - std = [ - 2 * self._std_weight_position * measurement[3], - 2 * self._std_weight_position * measurement[3], - 1e-2, - 2 * self._std_weight_position * measurement[3], - 10 * self._std_weight_velocity * measurement[3], - 10 * self._std_weight_velocity * measurement[3], - 1e-5, - 10 * self._std_weight_velocity * measurement[3], - ] - covariance = np.diag(np.square(std)) - return mean, covariance - - def predict(self, mean: np.ndarray, covariance: np.ndarray): - """ - Run Kalman filter prediction step. - - Args: - mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step. - covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step. - - Returns: - mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean. - covariance (np.ndarray): Covariance matrix of the predicted state. 
- - Examples: - >>> kf = KalmanFilterXYAH() - >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) - >>> covariance = np.eye(8) - >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance) - """ - std_pos = [ - self._std_weight_position * mean[3], - self._std_weight_position * mean[3], - 1e-2, - self._std_weight_position * mean[3], - ] - std_vel = [ - self._std_weight_velocity * mean[3], - self._std_weight_velocity * mean[3], - 1e-5, - self._std_weight_velocity * mean[3], - ] - motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) - - mean = np.dot(mean, self._motion_mat.T) - covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov - - return mean, covariance - - def project(self, mean: np.ndarray, covariance: np.ndarray): - """ - Project state distribution to measurement space. - - Args: - mean (np.ndarray): The state's mean vector (8 dimensional array). - covariance (np.ndarray): The state's covariance matrix (8x8 dimensional). - - Returns: - mean (np.ndarray): Projected mean of the given state estimate. - covariance (np.ndarray): Projected covariance matrix of the given state estimate. - - Examples: - >>> kf = KalmanFilterXYAH() - >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) - >>> covariance = np.eye(8) - >>> projected_mean, projected_covariance = kf.project(mean, covariance) - """ - std = [ - self._std_weight_position * mean[3], - self._std_weight_position * mean[3], - 1e-1, - self._std_weight_position * mean[3], - ] - innovation_cov = np.diag(np.square(std)) - - mean = np.dot(self._update_mat, mean) - covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) - return mean, covariance + innovation_cov - - def multi_predict(self, mean: np.ndarray, covariance: np.ndarray): - """ - Run Kalman filter prediction step for multiple object states (Vectorized version). - - Args: - mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. 
- covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. - - Returns: - mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8). - covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8). - - Examples: - >>> mean = np.random.rand(10, 8) # 10 object states - >>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states - >>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance) - """ - std_pos = [ - self._std_weight_position * mean[:, 3], - self._std_weight_position * mean[:, 3], - 1e-2 * np.ones_like(mean[:, 3]), - self._std_weight_position * mean[:, 3], - ] - std_vel = [ - self._std_weight_velocity * mean[:, 3], - self._std_weight_velocity * mean[:, 3], - 1e-5 * np.ones_like(mean[:, 3]), - self._std_weight_velocity * mean[:, 3], - ] - sqr = np.square(np.r_[std_pos, std_vel]).T - - motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] - motion_cov = np.asarray(motion_cov) - - mean = np.dot(mean, self._motion_mat.T) - left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) - covariance = np.dot(left, self._motion_mat.T) + motion_cov - - return mean, covariance - - def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray): - """ - Run Kalman filter correction step. - - Args: - mean (np.ndarray): The predicted state's mean vector (8 dimensional). - covariance (np.ndarray): The state's covariance matrix (8x8 dimensional). - measurement (np.ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center - position, a the aspect ratio, and h the height of the bounding box. - - Returns: - new_mean (np.ndarray): Measurement-corrected state mean. - new_covariance (np.ndarray): Measurement-corrected state covariance. 
- - Examples: - >>> kf = KalmanFilterXYAH() - >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) - >>> covariance = np.eye(8) - >>> measurement = np.array([1, 1, 1, 1]) - >>> new_mean, new_covariance = kf.update(mean, covariance, measurement) - """ - projected_mean, projected_cov = self.project(mean, covariance) - - chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False) - kalman_gain = scipy.linalg.cho_solve( - (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False - ).T - innovation = measurement - projected_mean - - new_mean = mean + np.dot(innovation, kalman_gain.T) - new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T)) - return new_mean, new_covariance - - def gating_distance( - self, - mean: np.ndarray, - covariance: np.ndarray, - measurements: np.ndarray, - only_position: bool = False, - metric: str = "maha", - ) -> np.ndarray: - """ - Compute gating distance between state distribution and measurements. - - A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square - distribution has 4 degrees of freedom, otherwise 2. - - Args: - mean (np.ndarray): Mean vector over the state distribution (8 dimensional). - covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional). - measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the - bounding box center position, a the aspect ratio, and h the height. - only_position (bool, optional): If True, distance computation is done with respect to box center position only. - metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the squared - Euclidean distance and 'maha' for the squared Mahalanobis distance. 
- - Returns: - (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between - (mean, covariance) and `measurements[i]`. - - Examples: - Compute gating distance using Mahalanobis metric: - >>> kf = KalmanFilterXYAH() - >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) - >>> covariance = np.eye(8) - >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]]) - >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric="maha") - """ - mean, covariance = self.project(mean, covariance) - if only_position: - mean, covariance = mean[:2], covariance[:2, :2] - measurements = measurements[:, :2] - - d = measurements - mean - if metric == "gaussian": - return np.sum(d * d, axis=1) - elif metric == "maha": - cholesky_factor = np.linalg.cholesky(covariance) - z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True) - return np.sum(z * z, axis=0) # square maha - else: - raise ValueError("Invalid distance metric") - - -class KalmanFilterXYWH(KalmanFilterXYAH): - """ - A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter. - - Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where - (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities. - The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct - observation of the state space (linear observation model). - - Attributes: - _motion_mat (np.ndarray): The motion matrix for the Kalman filter. - _update_mat (np.ndarray): The update matrix for the Kalman filter. - _std_weight_position (float): Standard deviation weight for position. - _std_weight_velocity (float): Standard deviation weight for velocity. - - Methods: - initiate: Create a track from an unassociated measurement. 
- predict: Run the Kalman filter prediction step. - project: Project the state distribution to measurement space. - multi_predict: Run the Kalman filter prediction step in a vectorized manner. - update: Run the Kalman filter correction step. - - Examples: - Create a Kalman filter and initialize a track - >>> kf = KalmanFilterXYWH() - >>> measurement = np.array([100, 50, 20, 40]) - >>> mean, covariance = kf.initiate(measurement) - >>> print(mean) - >>> print(covariance) - """ - - def initiate(self, measurement: np.ndarray): - """ - Create track from unassociated measurement. - - Args: - measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height. - - Returns: - mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean. - covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track. - - Examples: - >>> kf = KalmanFilterXYWH() - >>> measurement = np.array([100, 50, 20, 40]) - >>> mean, covariance = kf.initiate(measurement) - >>> print(mean) - [100. 50. 20. 40. 0. 0. 0. 0.] - >>> print(covariance) - [[ 4. 0. 0. 0. 0. 0. 0. 0.] - [ 0. 4. 0. 0. 0. 0. 0. 0.] - [ 0. 0. 4. 0. 0. 0. 0. 0.] - [ 0. 0. 0. 4. 0. 0. 0. 0.] - [ 0. 0. 0. 0. 0.25 0. 0. 0.] - [ 0. 0. 0. 0. 0. 0.25 0. 0.] - [ 0. 0. 0. 0. 0. 0. 0.25 0.] - [ 0. 0. 0. 0. 0. 0. 0. 
0.25]] - """ - mean_pos = measurement - mean_vel = np.zeros_like(mean_pos) - mean = np.r_[mean_pos, mean_vel] - - std = [ - 2 * self._std_weight_position * measurement[2], - 2 * self._std_weight_position * measurement[3], - 2 * self._std_weight_position * measurement[2], - 2 * self._std_weight_position * measurement[3], - 10 * self._std_weight_velocity * measurement[2], - 10 * self._std_weight_velocity * measurement[3], - 10 * self._std_weight_velocity * measurement[2], - 10 * self._std_weight_velocity * measurement[3], - ] - covariance = np.diag(np.square(std)) - return mean, covariance - - def predict(self, mean: np.ndarray, covariance: np.ndarray): - """ - Run Kalman filter prediction step. - - Args: - mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step. - covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step. - - Returns: - mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean. - covariance (np.ndarray): Covariance matrix of the predicted state. 
- - Examples: - >>> kf = KalmanFilterXYWH() - >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) - >>> covariance = np.eye(8) - >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance) - """ - std_pos = [ - self._std_weight_position * mean[2], - self._std_weight_position * mean[3], - self._std_weight_position * mean[2], - self._std_weight_position * mean[3], - ] - std_vel = [ - self._std_weight_velocity * mean[2], - self._std_weight_velocity * mean[3], - self._std_weight_velocity * mean[2], - self._std_weight_velocity * mean[3], - ] - motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) - - mean = np.dot(mean, self._motion_mat.T) - covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov - - return mean, covariance - - def project(self, mean: np.ndarray, covariance: np.ndarray): - """ - Project state distribution to measurement space. - - Args: - mean (np.ndarray): The state's mean vector (8 dimensional array). - covariance (np.ndarray): The state's covariance matrix (8x8 dimensional). - - Returns: - mean (np.ndarray): Projected mean of the given state estimate. - covariance (np.ndarray): Projected covariance matrix of the given state estimate. - - Examples: - >>> kf = KalmanFilterXYWH() - >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) - >>> covariance = np.eye(8) - >>> projected_mean, projected_cov = kf.project(mean, covariance) - """ - std = [ - self._std_weight_position * mean[2], - self._std_weight_position * mean[3], - self._std_weight_position * mean[2], - self._std_weight_position * mean[3], - ] - innovation_cov = np.diag(np.square(std)) - - mean = np.dot(self._update_mat, mean) - covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) - return mean, covariance + innovation_cov - - def multi_predict(self, mean: np.ndarray, covariance: np.ndarray): - """ - Run Kalman filter prediction step (Vectorized version). 
- - Args: - mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step. - covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step. - - Returns: - mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8). - covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8). - - Examples: - >>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors - >>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices - >>> kf = KalmanFilterXYWH() - >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance) - """ - std_pos = [ - self._std_weight_position * mean[:, 2], - self._std_weight_position * mean[:, 3], - self._std_weight_position * mean[:, 2], - self._std_weight_position * mean[:, 3], - ] - std_vel = [ - self._std_weight_velocity * mean[:, 2], - self._std_weight_velocity * mean[:, 3], - self._std_weight_velocity * mean[:, 2], - self._std_weight_velocity * mean[:, 3], - ] - sqr = np.square(np.r_[std_pos, std_vel]).T - - motion_cov = [np.diag(sqr[i]) for i in range(len(mean))] - motion_cov = np.asarray(motion_cov) - - mean = np.dot(mean, self._motion_mat.T) - left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2)) - covariance = np.dot(left, self._motion_mat.T) + motion_cov - - return mean, covariance - - def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray): - """ - Run Kalman filter correction step. - - Args: - mean (np.ndarray): The predicted state's mean vector (8 dimensional). - covariance (np.ndarray): The state's covariance matrix (8x8 dimensional). - measurement (np.ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center - position, w the width, and h the height of the bounding box. - - Returns: - new_mean (np.ndarray): Measurement-corrected state mean. 
- new_covariance (np.ndarray): Measurement-corrected state covariance. - - Examples: - >>> kf = KalmanFilterXYWH() - >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0]) - >>> covariance = np.eye(8) - >>> measurement = np.array([0.5, 0.5, 1.2, 1.2]) - >>> new_mean, new_covariance = kf.update(mean, covariance, measurement) - """ - return super().update(mean, covariance, measurement) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/matching.py b/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/matching.py deleted file mode 100644 index 4e74197..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/trackers/utils/matching.py +++ /dev/null @@ -1,157 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import numpy as np -import scipy -from scipy.spatial.distance import cdist - -from ultralytics.utils.metrics import batch_probiou, bbox_ioa - -try: - import lap # for linear_assignment - - assert lap.__version__ # verify package is not directory -except (ImportError, AssertionError, AttributeError): - from ultralytics.utils.checks import check_requirements - - check_requirements("lap>=0.5.12") # https://github.com/gatagat/lap - import lap - - -def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True): - """ - Perform linear assignment using either the scipy or lap.lapjv method. - - Args: - cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M). - thresh (float): Threshold for considering an assignment valid. - use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used. - - Returns: - matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches. - unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,). - unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,). 
- - Examples: - >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - >>> thresh = 5.0 - >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True) - """ - if cost_matrix.size == 0: - return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1])) - - if use_lap: - # Use lap.lapjv - # https://github.com/gatagat/lap - _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh) - matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0] - unmatched_a = np.where(x < 0)[0] - unmatched_b = np.where(y < 0)[0] - else: - # Use scipy.optimize.linear_sum_assignment - # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html - x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y - matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh]) - if len(matches) == 0: - unmatched_a = list(np.arange(cost_matrix.shape[0])) - unmatched_b = list(np.arange(cost_matrix.shape[1])) - else: - unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0])) - unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1])) - - return matches, unmatched_a, unmatched_b - - -def iou_distance(atracks: list, btracks: list) -> np.ndarray: - """ - Compute cost based on Intersection over Union (IoU) between tracks. - - Args: - atracks (List[STrack] | List[np.ndarray]): List of tracks 'a' or bounding boxes. - btracks (List[STrack] | List[np.ndarray]): List of tracks 'b' or bounding boxes. - - Returns: - (np.ndarray): Cost matrix computed based on IoU with shape (len(atracks), len(btracks)). 
- - Examples: - Compute IoU distance between two sets of tracks - >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])] - >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])] - >>> cost_matrix = iou_distance(atracks, btracks) - """ - if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray): - atlbrs = atracks - btlbrs = btracks - else: - atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks] - btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks] - - ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32) - if len(atlbrs) and len(btlbrs): - if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5: - ious = batch_probiou( - np.ascontiguousarray(atlbrs, dtype=np.float32), - np.ascontiguousarray(btlbrs, dtype=np.float32), - ).numpy() - else: - ious = bbox_ioa( - np.ascontiguousarray(atlbrs, dtype=np.float32), - np.ascontiguousarray(btlbrs, dtype=np.float32), - iou=True, - ) - return 1 - ious # cost matrix - - -def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray: - """ - Compute distance between tracks and detections based on embeddings. - - Args: - tracks (List[STrack]): List of tracks, where each track contains embedding features. - detections (List[BaseTrack]): List of detections, where each detection contains embedding features. - metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc. - - Returns: - (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks - and M is the number of detections. 
- - Examples: - Compute the embedding distance between tracks and detections using cosine metric - >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features - >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features - >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine") - """ - cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32) - if cost_matrix.size == 0: - return cost_matrix - det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32) - # for i, track in enumerate(tracks): - # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric)) - track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32) - cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features - return cost_matrix - - -def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray: - """ - Fuse cost matrix with detection scores to produce a single similarity matrix. - - Args: - cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M). - detections (List[BaseTrack]): List of detections, each containing a score attribute. - - Returns: - (np.ndarray): Fused similarity matrix with shape (N, M). 
- - Examples: - Fuse a cost matrix with detection scores - >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections - >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)] - >>> fused_matrix = fuse_score(cost_matrix, detections) - """ - if cost_matrix.size == 0: - return cost_matrix - iou_sim = 1 - cost_matrix - det_scores = np.array([det.score for det in detections]) - det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0) - fuse_sim = iou_sim * det_scores - return 1 - fuse_sim # fuse_cost diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/__init__.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/__init__.py deleted file mode 100644 index 948910f..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/__init__.py +++ /dev/null @@ -1,1599 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import contextlib -import importlib.metadata -import inspect -import json -import logging -import os -import platform -import re -import subprocess -import sys -import threading -import time -import warnings -from pathlib import Path -from threading import Lock -from types import SimpleNamespace -from typing import Union -from urllib.parse import unquote - -import cv2 -import numpy as np -import torch -import tqdm - -from ultralytics import __version__ -from ultralytics.utils.patches import imread, imshow, imwrite, torch_save # for patches - -# PyTorch Multi-GPU DDP Constants -RANK = int(os.getenv("RANK", -1)) -LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html - -# Other Constants -ARGV = sys.argv or ["", ""] # sometimes sys.argv = [] -FILE = Path(__file__).resolve() -ROOT = FILE.parents[1] # YOLO -ASSETS = ROOT / "assets" # default images -ASSETS_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0" # assets GitHub URL -DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml" -NUM_THREADS = min(8, 
max(1, os.cpu_count() - 1)) # number of YOLO multiprocessing threads -AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode -VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode -TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format -LOGGING_NAME = "ultralytics" -MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans -MACOS_VERSION = platform.mac_ver()[0] if MACOS else None -ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans -PYTHON_VERSION = platform.python_version() -TORCH_VERSION = torch.__version__ -TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision -IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode" -RKNN_CHIPS = frozenset( - { - "rk3588", - "rk3576", - "rk3566", - "rk3568", - "rk3562", - "rv1103", - "rv1106", - "rv1103b", - "rv1106b", - "rk2118", - } -) # Rockchip processors available for export -HELP_MSG = """ - Examples for running Ultralytics: - - 1. Install the ultralytics package: - - pip install ultralytics - - 2. Use the Python SDK: - - from ultralytics import YOLO - - # Load a model - model = YOLO("yolo11n.yaml") # build a new model from scratch - model = YOLO("yolo11n.pt") # load a pretrained model (recommended for training) - - # Use the model - results = model.train(data="coco8.yaml", epochs=3) # train the model - results = model.val() # evaluate model performance on the validation set - results = model("https://ultralytics.com/images/bus.jpg") # predict on an image - success = model.export(format="onnx") # export the model to ONNX format - - 3. 
Use the command line interface (CLI): - - Ultralytics 'yolo' CLI commands use the following syntax: - - yolo TASK MODE ARGS - - Where TASK (optional) is one of [detect, segment, classify, pose, obb] - MODE (required) is one of [train, val, predict, export, track, benchmark] - ARGS (optional) are any number of custom "arg=value" pairs like "imgsz=320" that override defaults. - See all ARGS at https://docs.ultralytics.com/usage/cfg or with "yolo cfg" - - - Train a detection model for 10 epochs with an initial learning_rate of 0.01 - yolo detect train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01 - - - Predict a YouTube video using a pretrained segmentation model at image size 320: - yolo segment predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320 - - - Val a pretrained detection model at batch-size 1 and image size 640: - yolo detect val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640 - - - Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required) - yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128 - - - Run special commands: - yolo help - yolo checks - yolo version - yolo settings - yolo copy-cfg - yolo cfg - - Docs: https://docs.ultralytics.com - Community: https://community.ultralytics.com - GitHub: https://github.com/ultralytics/ultralytics - """ - -# Settings and Environment Variables -torch.set_printoptions(linewidth=320, precision=4, profile="default") -np.set_printoptions(linewidth=320, formatter=dict(float_kind="{:11.5g}".format)) # format short g, %precision=5 -cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader) -os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress verbose TF compiler warnings in Colab -os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings -os.environ["KINETO_LOG_LEVEL"] = "5" # suppress 
verbose PyTorch profiler output when computing FLOPs - -if TQDM_RICH := str(os.getenv("YOLO_TQDM_RICH", False)).lower() == "true": - from tqdm import rich - - -class TQDM(rich.tqdm if TQDM_RICH else tqdm.tqdm): - """ - A custom TQDM progress bar class that extends the original tqdm functionality. - - This class modifies the behavior of the original tqdm progress bar based on global settings and provides - additional customization options for Ultralytics projects. The progress bar is automatically disabled when - VERBOSE is False or when explicitly disabled. - - Attributes: - disable (bool): Whether to disable the progress bar. Determined by the global VERBOSE setting and - any passed 'disable' argument. - bar_format (str): The format string for the progress bar. Uses the global TQDM_BAR_FORMAT if not - explicitly set. - - Methods: - __init__: Initialize the TQDM object with custom settings. - __iter__: Return self as iterator to satisfy Iterable interface. - - Examples: - >>> from ultralytics.utils import TQDM - >>> for i in TQDM(range(100)): - ... # Your processing code here - ... pass - """ - - def __init__(self, *args, **kwargs): - """ - Initialize a custom TQDM progress bar with Ultralytics-specific settings. - - Args: - *args (Any): Variable length argument list to be passed to the original tqdm constructor. - **kwargs (Any): Arbitrary keyword arguments to be passed to the original tqdm constructor. - - Notes: - - The progress bar is disabled if VERBOSE is False or if 'disable' is explicitly set to True in kwargs. - - The default bar format is set to TQDM_BAR_FORMAT unless overridden in kwargs. - - Examples: - >>> from ultralytics.utils import TQDM - >>> for i in TQDM(range(100)): - ... # Your code here - ... 
pass - """ - warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning) # suppress tqdm.rich warning - kwargs["disable"] = not VERBOSE or kwargs.get("disable", False) or LOGGER.getEffectiveLevel() > 20 - kwargs.setdefault("bar_format", TQDM_BAR_FORMAT) # override default value if passed - super().__init__(*args, **kwargs) - - def __iter__(self): - """Return self as iterator to satisfy Iterable interface.""" - return super().__iter__() - - -class DataExportMixin: - """ - Mixin class for exporting validation metrics or prediction results in various formats. - - This class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results - from classification, object detection, segmentation, or pose estimation tasks into various formats: Pandas - DataFrame, CSV, XML, HTML, JSON and SQLite (SQL). - - Methods: - to_df: Convert summary to a Pandas DataFrame. - to_csv: Export results as a CSV string. - to_xml: Export results as an XML string (requires `lxml`). - to_html: Export results as an HTML table. - to_json: Export results as a JSON string. - tojson: Deprecated alias for `to_json()`. - to_sql: Export results to an SQLite database. - - Examples: - >>> model = YOLO("yolo11n.pt") - >>> results = model("image.jpg") - >>> df = results.to_df() - >>> print(df) - >>> csv_data = results.to_csv() - >>> results.to_sql(table_name="yolo_results") - """ - - def to_df(self, normalize=False, decimals=5): - """ - Create a pandas DataFrame from the prediction results summary or validation metrics. - - Args: - normalize (bool, optional): Normalize numerical values for easier comparison. - decimals (int, optional): Decimal places to round floats. - - Returns: - (DataFrame): DataFrame containing the summary data. 
- """ - import pandas as pd # scope for faster 'import ultralytics' - - return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals)) - - def to_csv(self, normalize=False, decimals=5): - """ - Export results to CSV string format. - - Args: - normalize (bool, optional): Normalize numeric values. - decimals (int, optional): Decimal precision. - - Returns: - (str): CSV content as string. - """ - return self.to_df(normalize=normalize, decimals=decimals).to_csv() - - def to_xml(self, normalize=False, decimals=5): - """ - Export results to XML format. - - Args: - normalize (bool, optional): Normalize numeric values. - decimals (int, optional): Decimal precision. - - Returns: - (str): XML string. - - Notes: - Requires `lxml` package to be installed. - """ - df = self.to_df(normalize=normalize, decimals=decimals) - return '\n' if df.empty else df.to_xml(parser="etree") - - def to_html(self, normalize=False, decimals=5, index=False): - """ - Export results to HTML table format. - - Args: - normalize (bool, optional): Normalize numeric values. - decimals (int, optional): Decimal precision. - index (bool, optional): Whether to include index column in the HTML table. - - Returns: - (str): HTML representation of the results. - """ - df = self.to_df(normalize=normalize, decimals=decimals) - return "
" if df.empty else df.to_html(index=index) - - def tojson(self, normalize=False, decimals=5): - """Deprecated version of to_json().""" - LOGGER.warning("'result.tojson()' is deprecated, replace with 'result.to_json()'.") - return self.to_json(normalize, decimals) - - def to_json(self, normalize=False, decimals=5): - """ - Export results to JSON format. - - Args: - normalize (bool, optional): Normalize numeric values. - decimals (int, optional): Decimal precision. - - Returns: - (str): JSON-formatted string of the results. - """ - return self.to_df(normalize=normalize, decimals=decimals).to_json(orient="records", indent=2) - - def to_sql(self, normalize=False, decimals=5, table_name="results", db_path="results.db"): - """ - Save results to an SQLite database. - - Args: - normalize (bool, optional): Normalize numeric values. - decimals (int, optional): Decimal precision. - table_name (str, optional): Name of the SQL table. - db_path (str, optional): SQLite database file path. - """ - df = self.to_df(normalize, decimals) - if df.empty or df.columns.empty: # Exit if df is None or has no columns (i.e., no schema) - return - - import sqlite3 - - conn = sqlite3.connect(db_path) - cursor = conn.cursor() - - # Dynamically create table schema based on summary to support prediction and validation results export - columns = [] - for col in df.columns: - sample_val = df[col].dropna().iloc[0] if not df[col].dropna().empty else "" - if isinstance(sample_val, dict): - col_type = "TEXT" - elif isinstance(sample_val, (float, int)): - col_type = "REAL" - else: - col_type = "TEXT" - columns.append(f'"{col}" {col_type}') # Quote column names to handle special characters like hyphens - - # Create table (Drop table from db if it's already exist) - cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"') - cursor.execute(f'CREATE TABLE "{table_name}" (id INTEGER PRIMARY KEY AUTOINCREMENT, {", ".join(columns)})') - - for _, row in df.iterrows(): - values = [json.dumps(v) if isinstance(v, 
dict) else v for v in row] - column_names = ", ".join(f'"{col}"' for col in df.columns) - placeholders = ", ".join("?" for _ in df.columns) - cursor.execute(f'INSERT INTO "{table_name}" ({column_names}) VALUES ({placeholders})', values) - - conn.commit() - conn.close() - LOGGER.info(f"Results saved to SQL table '{table_name}' in '{db_path}'.") - - -class SimpleClass: - """ - A simple base class for creating objects with string representations of their attributes. - - This class provides a foundation for creating objects that can be easily printed or represented as strings, - showing all their non-callable attributes. It's useful for debugging and introspection of object states. - - Methods: - __str__: Return a human-readable string representation of the object. - __repr__: Return a machine-readable string representation of the object. - __getattr__: Provide a custom attribute access error message with helpful information. - - Examples: - >>> class MyClass(SimpleClass): - ... def __init__(self): - ... self.x = 10 - ... self.y = "hello" - >>> obj = MyClass() - >>> print(obj) - __main__.MyClass object with attributes: - - x: 10 - y: 'hello' - - Notes: - - This class is designed to be subclassed. It provides a convenient way to inspect object attributes. - - The string representation includes the module and class name of the object. - - Callable attributes and attributes starting with an underscore are excluded from the string representation. 
- """ - - def __str__(self): - """Return a human-readable string representation of the object.""" - attr = [] - for a in dir(self): - v = getattr(self, a) - if not callable(v) and not a.startswith("_"): - if isinstance(v, SimpleClass): - # Display only the module and class name for subclasses - s = f"{a}: {v.__module__}.{v.__class__.__name__} object" - else: - s = f"{a}: {repr(v)}" - attr.append(s) - return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr) - - def __repr__(self): - """Return a machine-readable string representation of the object.""" - return self.__str__() - - def __getattr__(self, attr): - """Provide a custom attribute access error message with helpful information.""" - name = self.__class__.__name__ - raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}") - - -class IterableSimpleNamespace(SimpleNamespace): - """ - An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration. - - This class extends the SimpleNamespace class with additional methods for iteration, string representation, - and attribute access. It is designed to be used as a convenient container for storing and accessing - configuration parameters. - - Methods: - __iter__: Return an iterator of key-value pairs from the namespace's attributes. - __str__: Return a human-readable string representation of the object. - __getattr__: Provide a custom attribute access error message with helpful information. - get: Retrieve the value of a specified key, or a default value if the key doesn't exist. - - Examples: - >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3) - >>> for k, v in cfg: - ... 
print(f"{k}: {v}") - a: 1 - b: 2 - c: 3 - >>> print(cfg) - a=1 - b=2 - c=3 - >>> cfg.get("b") - 2 - >>> cfg.get("d", "default") - 'default' - - Notes: - This class is particularly useful for storing configuration parameters in a more accessible - and iterable format compared to a standard dictionary. - """ - - def __iter__(self): - """Return an iterator of key-value pairs from the namespace's attributes.""" - return iter(vars(self).items()) - - def __str__(self): - """Return a human-readable string representation of the object.""" - return "\n".join(f"{k}={v}" for k, v in vars(self).items()) - - def __getattr__(self, attr): - """Provide a custom attribute access error message with helpful information.""" - name = self.__class__.__name__ - raise AttributeError( - f""" - '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics - 'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace - {DEFAULT_CFG_PATH} with the latest version from - https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml - """ - ) - - def get(self, key, default=None): - """Return the value of the specified key if it exists; otherwise, return the default value.""" - return getattr(self, key, default) - - -def plt_settings(rcparams=None, backend="Agg"): - """ - Decorator to temporarily set rc parameters and the backend for a plotting function. - - Args: - rcparams (dict, optional): Dictionary of rc parameters to set. - backend (str, optional): Name of the backend to use. - - Returns: - (Callable): Decorated function with temporarily set rc parameters and backend. - - Examples: - >>> @plt_settings({"font.size": 12}) - >>> def plot_function(): - ... plt.figure() - ... plt.plot([1, 2, 3]) - ... plt.show() - - >>> with plt_settings({"font.size": 12}): - ... plt.figure() - ... plt.plot([1, 2, 3]) - ... 
plt.show() - """ - if rcparams is None: - rcparams = {"font.size": 11} - - def decorator(func): - """Decorator to apply temporary rc parameters and backend to a function.""" - - def wrapper(*args, **kwargs): - """Set rc parameters and backend, call the original function, and restore the settings.""" - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - - original_backend = plt.get_backend() - switch = backend.lower() != original_backend.lower() - if switch: - plt.close("all") # auto-close()ing of figures upon backend switching is deprecated since 3.8 - plt.switch_backend(backend) - - # Plot with backend and always revert to original backend - try: - with plt.rc_context(rcparams): - result = func(*args, **kwargs) - finally: - if switch: - plt.close("all") - plt.switch_backend(original_backend) - return result - - return wrapper - - return decorator - - -def set_logging(name="LOGGING_NAME", verbose=True): - """ - Set up logging with UTF-8 encoding and configurable verbosity. - - This function configures logging for the Ultralytics library, setting the appropriate logging level and - formatter based on the verbosity flag and the current process rank. It handles special cases for Windows - environments where UTF-8 encoding might not be the default. - - Args: - name (str): Name of the logger. - verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise. - - Returns: - (logging.Logger): Configured logger object. - - Examples: - >>> set_logging(name="ultralytics", verbose=True) - >>> logger = logging.getLogger("ultralytics") - >>> logger.info("This is an info message") - - Notes: - - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible. - - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments. - - The function sets up a StreamHandler with the appropriate formatter and level. 
- - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers. - """ - level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings - - class PrefixFormatter(logging.Formatter): - def format(self, record): - """Format log records with prefixes based on level.""" - # Apply prefixes based on log level - if record.levelno == logging.WARNING: - prefix = "WARNING ⚠️" if not WINDOWS else "WARNING" - record.msg = f"{prefix} {record.msg}" - elif record.levelno == logging.ERROR: - prefix = "ERROR ❌" if not WINDOWS else "ERROR" - record.msg = f"{prefix} {record.msg}" - - # Handle emojis in message based on platform - formatted_message = super().format(record) - return emojis(formatted_message) - - formatter = PrefixFormatter("%(message)s") - - # Handle Windows UTF-8 encoding issues - if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8": - try: - # Attempt to reconfigure stdout to use UTF-8 encoding if possible - if hasattr(sys.stdout, "reconfigure"): - sys.stdout.reconfigure(encoding="utf-8") - # For environments where reconfigure is not available, wrap stdout in a TextIOWrapper - elif hasattr(sys.stdout, "buffer"): - import io - - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") - except Exception: - pass - - # Create and configure the StreamHandler with the appropriate formatter and level - stream_handler = logging.StreamHandler(sys.stdout) - stream_handler.setFormatter(formatter) - stream_handler.setLevel(level) - - # Set up the logger - logger = logging.getLogger(name) - logger.setLevel(level) - logger.addHandler(stream_handler) - logger.propagate = False - return logger - - -# Set logger -LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.) 
-for logger in "sentry_sdk", "urllib3.connectionpool": - logging.getLogger(logger).setLevel(logging.CRITICAL + 1) - - -def emojis(string=""): - """Return platform-dependent emoji-safe version of string.""" - return string.encode().decode("ascii", "ignore") if WINDOWS else string - - -class ThreadingLocked: - """ - A decorator class for ensuring thread-safe execution of a function or method. - - This class can be used as a decorator to make sure that if the decorated function is called from multiple threads, - only one thread at a time will be able to execute the function. - - Attributes: - lock (threading.Lock): A lock object used to manage access to the decorated function. - - Examples: - >>> from ultralytics.utils import ThreadingLocked - >>> @ThreadingLocked() - >>> def my_function(): - ... # Your code here - """ - - def __init__(self): - """Initialize the decorator class with a threading lock.""" - self.lock = threading.Lock() - - def __call__(self, f): - """Run thread-safe execution of function or method.""" - from functools import wraps - - @wraps(f) - def decorated(*args, **kwargs): - """Apply thread-safety to the decorated function or method.""" - with self.lock: - return f(*args, **kwargs) - - return decorated - - -class YAML: - """ - YAML utility class for efficient file operations with automatic C-implementation detection. - - This class provides optimized YAML loading and saving operations using PyYAML's fastest available implementation - (C-based when possible). It implements a singleton pattern with lazy initialization, allowing direct class method - usage without explicit instantiation. The class handles file path creation, validation, and character encoding - issues automatically. 
- - The implementation prioritizes performance through: - - Automatic C-based loader/dumper selection when available - - Singleton pattern to reuse the same instance - - Lazy initialization to defer import costs until needed - - Fallback mechanisms for handling problematic YAML content - - Attributes: - _instance: Internal singleton instance storage. - yaml: Reference to the PyYAML module. - SafeLoader: Best available YAML loader (CSafeLoader if available). - SafeDumper: Best available YAML dumper (CSafeDumper if available). - - Examples: - >>> data = YAML.load("config.yaml") - >>> data["new_value"] = 123 - >>> YAML.save("updated_config.yaml", data) - >>> YAML.print(data) - """ - - _instance = None - - @classmethod - def _get_instance(cls): - """Initialize singleton instance on first use.""" - if cls._instance is None: - cls._instance = cls() - return cls._instance - - def __init__(self): - """Initialize with optimal YAML implementation (C-based when available).""" - import yaml - - self.yaml = yaml - # Use C-based implementation if available for better performance - try: - self.SafeLoader = yaml.CSafeLoader - self.SafeDumper = yaml.CSafeDumper - except (AttributeError, ImportError): - self.SafeLoader = yaml.SafeLoader - self.SafeDumper = yaml.SafeDumper - - @classmethod - def save(cls, file="data.yaml", data=None, header=""): - """ - Save Python object as YAML file. - - Args: - file (str | Path): Path to save YAML file. - data (dict | None): Dict or compatible object to save. - header (str): Optional string to add at file beginning. 
- """ - instance = cls._get_instance() - if data is None: - data = {} - - # Create parent directories if needed - file = Path(file) - file.parent.mkdir(parents=True, exist_ok=True) - - # Convert non-serializable objects to strings - valid_types = int, float, str, bool, list, tuple, dict, type(None) - for k, v in data.items(): - if not isinstance(v, valid_types): - data[k] = str(v) - - # Write YAML file - with open(file, "w", errors="ignore", encoding="utf-8") as f: - if header: - f.write(header) - instance.yaml.dump(data, f, sort_keys=False, allow_unicode=True, Dumper=instance.SafeDumper) - - @classmethod - def load(cls, file="data.yaml", append_filename=False): - """ - Load YAML file to Python object with robust error handling. - - Args: - file (str | Path): Path to YAML file. - append_filename (bool): Whether to add filename to returned dict. - - Returns: - (dict): Loaded YAML content. - """ - instance = cls._get_instance() - assert str(file).endswith((".yaml", ".yml")), f"Not a YAML file: {file}" - - # Read file content - with open(file, errors="ignore", encoding="utf-8") as f: - s = f.read() - - # Try loading YAML with fallback for problematic characters - try: - data = instance.yaml.load(s, Loader=instance.SafeLoader) or {} - except Exception: - # Remove problematic characters and retry - s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s) - data = instance.yaml.load(s, Loader=instance.SafeLoader) or {} - - # Check for accidental user-error None strings (should be 'null' in YAML) - if "None" in data.values(): - data = {k: None if v == "None" else v for k, v in data.items()} - - if append_filename: - data["yaml_file"] = str(file) - return data - - @classmethod - def print(cls, yaml_file): - """ - Pretty print YAML file or object to console. - - Args: - yaml_file (str | Path | dict): Path to YAML file or dict to print. 
- """ - instance = cls._get_instance() - - # Load file if path provided - yaml_dict = cls.load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file - - # Use -1 for unlimited width in C implementation - dump = instance.yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=-1, Dumper=instance.SafeDumper) - - LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}") - - -# Default configuration -DEFAULT_CFG_DICT = YAML.load(DEFAULT_CFG_PATH) -DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys() -DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT) - - -def read_device_model() -> str: - """ - Read the device model information from the system and cache it for quick access. - - Returns: - (str): Kernel release information. - """ - return platform.release().lower() - - -def is_ubuntu() -> bool: - """ - Check if the OS is Ubuntu. - - Returns: - (bool): True if OS is Ubuntu, False otherwise. - """ - try: - with open("/etc/os-release") as f: - return "ID=ubuntu" in f.read() - except FileNotFoundError: - return False - - -def is_colab(): - """ - Check if the current script is running inside a Google Colab notebook. - - Returns: - (bool): True if running inside a Colab notebook, False otherwise. - """ - return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ - - -def is_kaggle(): - """ - Check if the current script is running inside a Kaggle kernel. - - Returns: - (bool): True if running inside a Kaggle kernel, False otherwise. - """ - return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com" - - -def is_jupyter(): - """ - Check if the current script is running inside a Jupyter Notebook. - - Returns: - (bool): True if running inside a Jupyter Notebook, False otherwise. - - Notes: - - Only works on Colab and Kaggle, other environments like Jupyterlab and Paperspace are not reliably detectable. 
- - "get_ipython" in globals() method suffers false positives when IPython package installed manually. - """ - return IS_COLAB or IS_KAGGLE - - -def is_runpod(): - """ - Check if the current script is running inside a RunPod container. - - Returns: - (bool): True if running in RunPod, False otherwise. - """ - return "RUNPOD_POD_ID" in os.environ - - -def is_docker() -> bool: - """ - Determine if the script is running inside a Docker container. - - Returns: - (bool): True if the script is running inside a Docker container, False otherwise. - """ - try: - return os.path.exists("/.dockerenv") - except Exception: - return False - - -def is_raspberrypi() -> bool: - """ - Determine if the Python environment is running on a Raspberry Pi. - - Returns: - (bool): True if running on a Raspberry Pi, False otherwise. - """ - return "rpi" in DEVICE_MODEL - - -def is_jetson() -> bool: - """ - Determine if the Python environment is running on an NVIDIA Jetson device. - - Returns: - (bool): True if running on an NVIDIA Jetson device, False otherwise. - """ - return "tegra" in DEVICE_MODEL - - -def is_online() -> bool: - """ - Check internet connectivity by attempting to connect to a known online host. - - Returns: - (bool): True if connection is successful, False otherwise. - """ - try: - assert str(os.getenv("YOLO_OFFLINE", "")).lower() != "true" # check if ENV var YOLO_OFFLINE="True" - import socket - - for dns in ("1.1.1.1", "8.8.8.8"): # check Cloudflare and Google DNS - socket.create_connection(address=(dns, 80), timeout=2.0).close() - return True - except Exception: - return False - - -def is_pip_package(filepath: str = __name__) -> bool: - """ - Determine if the file at the given filepath is part of a pip package. - - Args: - filepath (str): The filepath to check. - - Returns: - (bool): True if the file is part of a pip package, False otherwise. 
- """ - import importlib.util - - # Get the spec for the module - spec = importlib.util.find_spec(filepath) - - # Return whether the spec is not None and the origin is not None (indicating it is a package) - return spec is not None and spec.origin is not None - - -def is_dir_writeable(dir_path: Union[str, Path]) -> bool: - """ - Check if a directory is writeable. - - Args: - dir_path (str | Path): The path to the directory. - - Returns: - (bool): True if the directory is writeable, False otherwise. - """ - return os.access(str(dir_path), os.W_OK) - - -def is_pytest_running(): - """ - Determine whether pytest is currently running or not. - - Returns: - (bool): True if pytest is running, False otherwise. - """ - return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem) - - -def is_github_action_running() -> bool: - """ - Determine if the current environment is a GitHub Actions runner. - - Returns: - (bool): True if the current environment is a GitHub Actions runner, False otherwise. - """ - return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ - - -def get_git_dir(): - """ - Determine whether the current file is part of a git repository and if so, return the repository root directory. - - Returns: - (Path | None): Git root directory if found or None if not found. - """ - for d in Path(__file__).parents: - if (d / ".git").is_dir(): - return d - - -def is_git_dir(): - """ - Determine whether the current file is part of a git repository. - - Returns: - (bool): True if current file is part of a git repository. - """ - return GIT_DIR is not None - - -def get_git_origin_url(): - """ - Retrieve the origin URL of a git repository. - - Returns: - (str | None): The origin URL of the git repository or None if not git directory. 
- """ - if IS_GIT_DIR: - try: - origin = subprocess.check_output(["git", "config", "--get", "remote.origin.url"]) - return origin.decode().strip() - except subprocess.CalledProcessError: - return None - - -def get_git_branch(): - """ - Return the current git branch name. If not in a git repository, return None. - - Returns: - (str | None): The current git branch name or None if not a git directory. - """ - if IS_GIT_DIR: - try: - origin = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) - return origin.decode().strip() - except subprocess.CalledProcessError: - return None - - -def get_default_args(func): - """ - Return a dictionary of default arguments for a function. - - Args: - func (callable): The function to inspect. - - Returns: - (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter. - """ - signature = inspect.signature(func) - return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty} - - -def get_ubuntu_version(): - """ - Retrieve the Ubuntu version if the OS is Ubuntu. - - Returns: - (str): Ubuntu version or None if not an Ubuntu OS. - """ - if is_ubuntu(): - try: - with open("/etc/os-release") as f: - return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1] - except (FileNotFoundError, AttributeError): - return None - - -def get_user_config_dir(sub_dir="Ultralytics"): - """ - Return the appropriate config directory based on the environment operating system. - - Args: - sub_dir (str): The name of the subdirectory to create. - - Returns: - (Path): The path to the user config directory. 
- """ - if WINDOWS: - path = Path.home() / "AppData" / "Roaming" / sub_dir - elif MACOS: # macOS - path = Path.home() / "Library" / "Application Support" / sub_dir - elif LINUX: - path = Path.home() / ".config" / sub_dir - else: - raise ValueError(f"Unsupported operating system: {platform.system()}") - - # GCP and AWS lambda fix, only /tmp is writeable - if not is_dir_writeable(path.parent): - LOGGER.warning( - f"user config directory '{path}' is not writeable, defaulting to '/tmp' or CWD. " - "Alternatively you can define a YOLO_CONFIG_DIR environment variable for this path." - ) - path = Path("/tmp") / sub_dir if is_dir_writeable("/tmp") else Path().cwd() / sub_dir - - # Create the subdirectory if it does not exist - path.mkdir(parents=True, exist_ok=True) - - return path - - -# Define constants (required below) -DEVICE_MODEL = read_device_model() # is_jetson() and is_raspberrypi() depend on this constant -ONLINE = is_online() -IS_COLAB = is_colab() -IS_KAGGLE = is_kaggle() -IS_DOCKER = is_docker() -IS_JETSON = is_jetson() -IS_JUPYTER = is_jupyter() -IS_PIP_PACKAGE = is_pip_package() -IS_RASPBERRYPI = is_raspberrypi() -GIT_DIR = get_git_dir() -IS_GIT_DIR = is_git_dir() -USER_CONFIG_DIR = Path(os.getenv("YOLO_CONFIG_DIR") or get_user_config_dir()) # Ultralytics settings dir -SETTINGS_FILE = USER_CONFIG_DIR / "settings.json" - - -def colorstr(*input): - r""" - Color a string based on the provided color and style arguments using ANSI escape codes. - - This function can be called in two ways: - - colorstr('color', 'style', 'your string') - - colorstr('your string') - - In the second form, 'blue' and 'bold' will be applied by default. - - Args: - *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments, - and the last string is the one to be colored. - - Returns: - (str): The input string wrapped with ANSI escape codes for the specified color and style. 
- - Notes: - Supported Colors and Styles: - - Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white' - - Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow', - 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white' - - Misc: 'end', 'bold', 'underline' - - Examples: - >>> colorstr("blue", "bold", "hello world") - >>> "\033[34m\033[1mhello world\033[0m" - - References: - https://en.wikipedia.org/wiki/ANSI_escape_code - """ - *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string - colors = { - "black": "\033[30m", # basic colors - "red": "\033[31m", - "green": "\033[32m", - "yellow": "\033[33m", - "blue": "\033[34m", - "magenta": "\033[35m", - "cyan": "\033[36m", - "white": "\033[37m", - "bright_black": "\033[90m", # bright colors - "bright_red": "\033[91m", - "bright_green": "\033[92m", - "bright_yellow": "\033[93m", - "bright_blue": "\033[94m", - "bright_magenta": "\033[95m", - "bright_cyan": "\033[96m", - "bright_white": "\033[97m", - "end": "\033[0m", # misc - "bold": "\033[1m", - "underline": "\033[4m", - } - return "".join(colors[x] for x in args) + f"{string}" + colors["end"] - - -def remove_colorstr(input_string): - """ - Remove ANSI escape codes from a string, effectively un-coloring it. - - Args: - input_string (str): The string to remove color and style from. - - Returns: - (str): A new string with all ANSI escape codes removed. - - Examples: - >>> remove_colorstr(colorstr("blue", "bold", "hello world")) - >>> "hello world" - """ - ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]") - return ansi_escape.sub("", input_string) - - -class TryExcept(contextlib.ContextDecorator): - """ - Ultralytics TryExcept class for handling exceptions gracefully. - - This class can be used as a decorator or context manager to catch exceptions and optionally print warning messages. 
- It allows code to continue execution even when exceptions occur, which is useful for non-critical operations. - - Attributes: - msg (str): Optional message to display when an exception occurs. - verbose (bool): Whether to print the exception message. - - Examples: - As a decorator: - >>> @TryExcept(msg="Error occurred in func", verbose=True) - >>> def func(): - >>> # Function logic here - >>> pass - - As a context manager: - >>> with TryExcept(msg="Error occurred in block", verbose=True): - >>> # Code block here - >>> pass - """ - - def __init__(self, msg="", verbose=True): - """Initialize TryExcept class with optional message and verbosity settings.""" - self.msg = msg - self.verbose = verbose - - def __enter__(self): - """Execute when entering TryExcept context, initialize instance.""" - pass - - def __exit__(self, exc_type, value, traceback): - """Define behavior when exiting a 'with' block, print error message if necessary.""" - if self.verbose and value: - LOGGER.warning(f"{self.msg}{': ' if self.msg else ''}{value}") - return True - - -class Retry(contextlib.ContextDecorator): - """ - Retry class for function execution with exponential backoff. - - This decorator can be used to retry a function on exceptions, up to a specified number of times with an - exponentially increasing delay between retries. It's useful for handling transient failures in network - operations or other unreliable processes. - - Attributes: - times (int): Maximum number of retry attempts. - delay (int): Initial delay between retries in seconds. 
- - Examples: - Example usage as a decorator: - >>> @Retry(times=3, delay=2) - >>> def test_func(): - >>> # Replace with function logic that may raise exceptions - >>> return True - """ - - def __init__(self, times=3, delay=2): - """Initialize Retry class with specified number of retries and delay.""" - self.times = times - self.delay = delay - self._attempts = 0 - - def __call__(self, func): - """Decorator implementation for Retry with exponential backoff.""" - - def wrapped_func(*args, **kwargs): - """Apply retries to the decorated function or method.""" - self._attempts = 0 - while self._attempts < self.times: - try: - return func(*args, **kwargs) - except Exception as e: - self._attempts += 1 - LOGGER.warning(f"Retry {self._attempts}/{self.times} failed: {e}") - if self._attempts >= self.times: - raise e - time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay - - return wrapped_func - - -def threaded(func): - """ - Multi-thread a target function by default and return the thread or function result. - - This decorator provides flexible execution of the target function, either in a separate thread or synchronously. - By default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument - which is removed from kwargs before calling the function. - - Args: - func (callable): The function to be potentially executed in a separate thread. - - Returns: - (callable): A wrapper function that either returns a daemon thread or the direct function result. - - Examples: - >>> @threaded - ... def process_data(data): - ... 
return data - >>> - >>> thread = process_data(my_data) # Runs in background thread - >>> result = process_data(my_data, threaded=False) # Runs synchronously, returns function result - """ - - def wrapper(*args, **kwargs): - """Multi-thread a given function based on 'threaded' kwarg and return the thread or function result.""" - if kwargs.pop("threaded", True): # run in thread - thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True) - thread.start() - return thread - else: - return func(*args, **kwargs) - - return wrapper - - -def set_sentry(): - """ - Initialize the Sentry SDK for error tracking and reporting. - - Only used if sentry_sdk package is installed and sync=True in settings. Run 'yolo settings' to see and update - settings. - - Conditions required to send errors (ALL conditions must be met or no errors will be reported): - - sentry_sdk package is installed - - sync=True in YOLO settings - - pytest is not running - - running in a pip package installation - - running in a non-git directory - - running with rank -1 or 0 - - online environment - - CLI used to run package (checked with 'yolo' as the name of the main CLI command) - """ - if ( - not SETTINGS["sync"] - or RANK not in {-1, 0} - or Path(ARGV[0]).name != "yolo" - or TESTS_RUNNING - or not ONLINE - or not IS_PIP_PACKAGE - or IS_GIT_DIR - ): - return - # If sentry_sdk package is not installed then return and do not use Sentry - try: - import sentry_sdk # noqa - except ImportError: - return - - def before_send(event, hint): - """ - Modify the event before sending it to Sentry based on specific exception types and messages. - - Args: - event (dict): The event dictionary containing information about the error. - hint (dict): A dictionary containing additional information about the error. - - Returns: - (dict | None): The modified event or None if the event should not be sent to Sentry. 
- """ - if "exc_info" in hint: - exc_type, exc_value, _ = hint["exc_info"] - if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value): - return None # do not send event - - event["tags"] = { - "sys_argv": ARGV[0], - "sys_argv_name": Path(ARGV[0]).name, - "install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other", - "os": ENVIRONMENT, - } - return event - - sentry_sdk.init( - dsn="https://888e5a0778212e1d0314c37d4b9aae5d@o4504521589325824.ingest.us.sentry.io/4504521592406016", - debug=False, - auto_enabling_integrations=False, - traces_sample_rate=1.0, - release=__version__, - environment="runpod" if is_runpod() else "production", - before_send=before_send, - ignore_errors=[KeyboardInterrupt, FileNotFoundError], - ) - sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash - - -class JSONDict(dict): - """ - A dictionary-like class that provides JSON persistence for its contents. - - This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are - modified. It ensures thread-safe operations using a lock and handles JSON serialization of Path objects. - - Attributes: - file_path (Path): The path to the JSON file used for persistence. - lock (threading.Lock): A lock object to ensure thread-safe operations. - - Methods: - _load: Load the data from the JSON file into the dictionary. - _save: Save the current state of the dictionary to the JSON file. - __setitem__: Store a key-value pair and persist it to disk. - __delitem__: Remove an item and update the persistent storage. - update: Update the dictionary and persist changes. - clear: Clear all entries and update the persistent storage. 
- - Examples: - >>> json_dict = JSONDict("data.json") - >>> json_dict["key"] = "value" - >>> print(json_dict["key"]) - value - >>> del json_dict["key"] - >>> json_dict.update({"new_key": "new_value"}) - >>> json_dict.clear() - """ - - def __init__(self, file_path: Union[str, Path] = "data.json"): - """Initialize a JSONDict object with a specified file path for JSON persistence.""" - super().__init__() - self.file_path = Path(file_path) - self.lock = Lock() - self._load() - - def _load(self): - """Load the data from the JSON file into the dictionary.""" - try: - if self.file_path.exists(): - with open(self.file_path) as f: - self.update(json.load(f)) - except json.JSONDecodeError: - LOGGER.warning(f"Error decoding JSON from {self.file_path}. Starting with an empty dictionary.") - except Exception as e: - LOGGER.error(f"Error reading from {self.file_path}: {e}") - - def _save(self): - """Save the current state of the dictionary to the JSON file.""" - try: - self.file_path.parent.mkdir(parents=True, exist_ok=True) - with open(self.file_path, "w", encoding="utf-8") as f: - json.dump(dict(self), f, indent=2, default=self._json_default) - except Exception as e: - LOGGER.error(f"Error writing to {self.file_path}: {e}") - - @staticmethod - def _json_default(obj): - """Handle JSON serialization of Path objects.""" - if isinstance(obj, Path): - return str(obj) - raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable") - - def __setitem__(self, key, value): - """Store a key-value pair and persist to disk.""" - with self.lock: - super().__setitem__(key, value) - self._save() - - def __delitem__(self, key): - """Remove an item and update the persistent storage.""" - with self.lock: - super().__delitem__(key) - self._save() - - def __str__(self): - """Return a pretty-printed JSON string representation of the dictionary.""" - contents = json.dumps(dict(self), indent=2, ensure_ascii=False, default=self._json_default) - return 
f'JSONDict("{self.file_path}"):\n{contents}' - - def update(self, *args, **kwargs): - """Update the dictionary and persist changes.""" - with self.lock: - super().update(*args, **kwargs) - self._save() - - def clear(self): - """Clear all entries and update the persistent storage.""" - with self.lock: - super().clear() - self._save() - - -class SettingsManager(JSONDict): - """ - SettingsManager class for managing and persisting Ultralytics settings. - - This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default - values. It validates settings on initialization and provides methods to update or reset settings. The settings - include directories for datasets, weights, and runs, as well as various integration flags. - - Attributes: - file (Path): The path to the JSON file used for persistence. - version (str): The version of the settings schema. - defaults (dict): A dictionary containing default settings. - help_msg (str): A help message for users on how to view and update settings. - - Methods: - _validate_settings: Validate the current settings and reset if necessary. - update: Update settings, validating keys and types. - reset: Reset the settings to default and save them. 
- - Examples: - Initialize and update settings: - >>> settings = SettingsManager() - >>> settings.update(runs_dir="/new/runs/dir") - >>> print(settings["runs_dir"]) - /new/runs/dir - """ - - def __init__(self, file=SETTINGS_FILE, version="0.0.6"): - """Initialize the SettingsManager with default settings and load user settings.""" - import hashlib - import uuid - - from ultralytics.utils.torch_utils import torch_distributed_zero_first - - root = GIT_DIR or Path() - datasets_root = (root.parent if GIT_DIR and is_dir_writeable(root.parent) else root).resolve() - - self.file = Path(file) - self.version = version - self.defaults = { - "settings_version": version, # Settings schema version - "datasets_dir": str(datasets_root / "datasets"), # Datasets directory - "weights_dir": str(root / "weights"), # Model weights directory - "runs_dir": str(root / "runs"), # Experiment runs directory - "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # SHA-256 anonymized UUID hash - "sync": True, # Enable synchronization - "api_key": "", # Ultralytics API Key - "openai_api_key": "", # OpenAI API Key - "clearml": True, # ClearML integration - "comet": True, # Comet integration - "dvc": True, # DVC integration - "hub": True, # Ultralytics HUB integration - "mlflow": True, # MLflow integration - "neptune": True, # Neptune integration - "raytune": True, # Ray Tune integration - "tensorboard": False, # TensorBoard logging - "wandb": False, # Weights & Biases logging - "vscode_msg": True, # VSCode message - "openvino_msg": True, # OpenVINO export on Intel CPU message - } - - self.help_msg = ( - f"\nView Ultralytics Settings with 'yolo settings' or at '{self.file}'" - "\nUpdate Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. " - "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings." 
- ) - - with torch_distributed_zero_first(LOCAL_RANK): - super().__init__(self.file) - - if not self.file.exists() or not self: # Check if file doesn't exist or is empty - LOGGER.info(f"Creating new Ultralytics Settings v{version} file ✅ {self.help_msg}") - self.reset() - - self._validate_settings() - - def _validate_settings(self): - """Validate the current settings and reset if necessary.""" - correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys()) - correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items()) - correct_version = self.get("settings_version", "") == self.version - - if not (correct_keys and correct_types and correct_version): - LOGGER.warning( - "Ultralytics settings reset to default values. This may be due to a possible problem " - f"with your settings or a recent ultralytics package update. {self.help_msg}" - ) - self.reset() - - if self.get("datasets_dir") == self.get("runs_dir"): - LOGGER.warning( - f"Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' " - f"must be different than 'runs_dir: {self.get('runs_dir')}'. " - f"Please change one to avoid possible issues during training. {self.help_msg}" - ) - - def __setitem__(self, key, value): - """Update one key: value pair.""" - self.update({key: value}) - - def update(self, *args, **kwargs): - """Update settings, validating keys and types.""" - for arg in args: - if isinstance(arg, dict): - kwargs.update(arg) - for k, v in kwargs.items(): - if k not in self.defaults: - raise KeyError(f"No Ultralytics setting '{k}'. {self.help_msg}") - t = type(self.defaults[k]) - if not isinstance(v, t): - raise TypeError( - f"Ultralytics setting '{k}' must be '{t.__name__}' type, not '{type(v).__name__}'. 
{self.help_msg}" - ) - super().update(*args, **kwargs) - - def reset(self): - """Reset the settings to default and save them.""" - self.clear() - self.update(self.defaults) - - -def deprecation_warn(arg, new_arg=None): - """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument.""" - msg = f"'{arg}' is deprecated and will be removed in in the future." - if new_arg is not None: - msg += f" Use '{new_arg}' instead." - LOGGER.warning(msg) - - -def clean_url(url): - """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt.""" - url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows - return unquote(url).split("?", 1)[0] # '%2F' to '/', split https://url.com/file.txt?auth - - -def url2file(url): - """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt.""" - return Path(clean_url(url)).name - - -def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str: - """Display a message to install Ultralytics-Snippets for VS Code if not already installed.""" - path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / ".vscode/extensions" - obs_file = path / ".obsolete" # file tracks uninstalled extensions, while source directory remains - installed = any(path.glob(f"{ext}*")) and ext not in (obs_file.read_text("utf-8") if obs_file.exists() else "") - url = "https://docs.ultralytics.com/integrations/vscode" - return "" if installed else f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at {url}" - - -# Run below code on utils init ------------------------------------------------------------------------------------ - -# Check first-install steps -PREFIX = colorstr("Ultralytics: ") -SETTINGS = SettingsManager() # initialize settings -PERSISTENT_CACHE = JSONDict(USER_CONFIG_DIR / "persistent_cache.json") # initialize persistent cache -DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory 
-WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory -RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory -ENVIRONMENT = ( - "Colab" - if IS_COLAB - else "Kaggle" - if IS_KAGGLE - else "Jupyter" - if IS_JUPYTER - else "Docker" - if IS_DOCKER - else platform.system() -) -TESTS_RUNNING = is_pytest_running() or is_github_action_running() -set_sentry() - -# Apply monkey patches -torch.save = torch_save -if WINDOWS: - # Apply cv2 patches for non-ASCII and non-UTF characters in image paths - cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/autobatch.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/autobatch.py deleted file mode 100644 index 3dfd8c7..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/autobatch.py +++ /dev/null @@ -1,119 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch.""" - -import os -from copy import deepcopy -from typing import Union - -import numpy as np -import torch - -from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr -from ultralytics.utils.torch_utils import autocast, profile_ops - - -def check_train_batch_size( - model: torch.nn.Module, - imgsz: int = 640, - amp: bool = True, - batch: Union[int, float] = -1, - max_num_obj: int = 1, -) -> int: - """ - Compute optimal YOLO training batch size using the autobatch() function. - - Args: - model (torch.nn.Module): YOLO model to check batch size for. - imgsz (int, optional): Image size used for training. - amp (bool, optional): Use automatic mixed precision if True. - batch (int | float, optional): Fraction of GPU memory to use. If -1, use default. - max_num_obj (int, optional): The maximum number of objects from dataset. - - Returns: - (int): Optimal batch size computed using the autobatch() function. 
- - Notes: - If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use. - Otherwise, a default fraction of 0.6 is used. - """ - with autocast(enabled=amp): - return autobatch( - deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj - ) - - -def autobatch( - model: torch.nn.Module, - imgsz: int = 640, - fraction: float = 0.60, - batch_size: int = DEFAULT_CFG.batch, - max_num_obj: int = 1, -) -> int: - """ - Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory. - - Args: - model (torch.nn.Module): YOLO model to compute batch size for. - imgsz (int, optional): The image size used as input for the YOLO model. - fraction (float, optional): The fraction of available CUDA memory to use. - batch_size (int, optional): The default batch size to use if an error is detected. - max_num_obj (int, optional): The maximum number of objects from dataset. - - Returns: - (int): The optimal batch size. - """ - # Check device - prefix = colorstr("AutoBatch: ") - LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.") - device = next(model.parameters()).device # get model device - if device.type in {"cpu", "mps"}: - LOGGER.warning(f"{prefix}intended for CUDA devices, using default batch-size {batch_size}") - return batch_size - if torch.backends.cudnn.benchmark: - LOGGER.warning(f"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}") - return batch_size - - # Inspect CUDA memory - gb = 1 << 30 # bytes to GiB (1024 ** 3) - d = f"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}" # 'CUDA:0' - properties = torch.cuda.get_device_properties(device) # device properties - t = properties.total_memory / gb # GiB total - r = torch.cuda.memory_reserved(device) / gb # GiB reserved - a = torch.cuda.memory_allocated(device) / gb # GiB allocated - f = t - (r + a) # GiB free - 
LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free") - - # Profile batch sizes - batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64] - try: - img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes] - results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj) - - # Fit a solution - xy = [ - [x, y[2]] - for i, (x, y) in enumerate(zip(batch_sizes, results)) - if y # valid result - and isinstance(y[2], (int, float)) # is numeric - and 0 < y[2] < t # between 0 and GPU limit - and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2]) # first item or increasing memory - ] - fit_x, fit_y = zip(*xy) if xy else ([], []) - p = np.polyfit(fit_x, fit_y, deg=1) # first-degree polynomial fit in log space - b = int((round(f * fraction) - p[1]) / p[0]) # y intercept (optimal batch size) - if None in results: # some sizes failed - i = results.index(None) # first fail index - if b >= batch_sizes[i]: # y intercept above failure point - b = batch_sizes[max(i - 1, 0)] # select prior safe point - if b < 1 or b > 1024: # b outside of safe range - LOGGER.warning(f"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.") - b = batch_size - - fraction = (np.polyval(p, b) + r + a) / t # predicted fraction - LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅") - return b - except Exception as e: - LOGGER.warning(f"{prefix}error detected: {e}, using default batch-size {batch_size}.") - return batch_size - finally: - torch.cuda.empty_cache() diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/autodevice.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/autodevice.py deleted file mode 100644 index e8c93b4..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/autodevice.py +++ /dev/null @@ -1,206 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - 
https://ultralytics.com/license - -from typing import Any, Dict, List, Optional - -from ultralytics.utils import LOGGER -from ultralytics.utils.checks import check_requirements - - -class GPUInfo: - """ - Manages NVIDIA GPU information via pynvml with robust error handling. - - Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle - GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml - library by logging warnings and disabling related features, preventing application crashes. - - Includes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU - selection. Manages NVML initialization and shutdown internally. - - Attributes: - pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`. - nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded, - False otherwise. - gpu_stats (List[Dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. Populated on - initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%), - 'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W), - 'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail. - - Methods: - refresh_stats: Refresh the internal gpu_stats list by querying NVML. - print_status: Print GPU status in a compact table format using current stats. - select_idle_gpu: Select the most idle GPUs based on utilization and free memory. - shutdown: Shut down NVML if it was initialized. 
- - Examples: - Initialize GPUInfo and print status - >>> gpu_info = GPUInfo() - >>> gpu_info.print_status() - - Select idle GPUs with minimum memory requirements - >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2) - >>> print(f"Selected GPU indices: {selected}") - """ - - def __init__(self): - """Initialize GPUInfo, attempting to import and initialize pynvml.""" - self.pynvml: Optional[Any] = None - self.nvml_available: bool = False - self.gpu_stats: List[Dict[str, Any]] = [] - - try: - check_requirements("pynvml>=12.0.0") - self.pynvml = __import__("pynvml") - self.pynvml.nvmlInit() - self.nvml_available = True - self.refresh_stats() - except Exception as e: - LOGGER.warning(f"Failed to initialize pynvml, GPU stats disabled: {e}") - - def __del__(self): - """Ensure NVML is shut down when the object is garbage collected.""" - self.shutdown() - - def shutdown(self): - """Shut down NVML if it was initialized.""" - if self.nvml_available and self.pynvml: - try: - self.pynvml.nvmlShutdown() - except Exception: - pass - self.nvml_available = False - - def refresh_stats(self): - """Refresh the internal gpu_stats list by querying NVML.""" - self.gpu_stats = [] - if not self.nvml_available or not self.pynvml: - return - - try: - device_count = self.pynvml.nvmlDeviceGetCount() - for i in range(device_count): - self.gpu_stats.append(self._get_device_stats(i)) - except Exception as e: - LOGGER.warning(f"Error during device query: {e}") - self.gpu_stats = [] - - def _get_device_stats(self, index: int) -> Dict[str, Any]: - """Get stats for a single GPU device.""" - handle = self.pynvml.nvmlDeviceGetHandleByIndex(index) - memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle) - util = self.pynvml.nvmlDeviceGetUtilizationRates(handle) - - def safe_get(func, *args, default=-1, divisor=1): - try: - val = func(*args) - return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val - except Exception: - return default - - temp_type = 
getattr(self.pynvml, "NVML_TEMPERATURE_GPU", -1) - - return { - "index": index, - "name": self.pynvml.nvmlDeviceGetName(handle), - "utilization": util.gpu if util else -1, - "memory_used": memory.used >> 20 if memory else -1, # Convert bytes to MiB - "memory_total": memory.total >> 20 if memory else -1, - "memory_free": memory.free >> 20 if memory else -1, - "temperature": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type), - "power_draw": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000), # Convert mW to W - "power_limit": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000), - } - - def print_status(self): - """Print GPU status in a compact table format using current stats.""" - self.refresh_stats() - if not self.gpu_stats: - LOGGER.warning("No GPU stats available.") - return - - stats = self.gpu_stats - name_len = max(len(gpu.get("name", "N/A")) for gpu in stats) - hdr = f"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}" - LOGGER.info(f"\n--- GPU Status ---\n{hdr}\n{'-' * len(hdr)}") - - for gpu in stats: - u = f"{gpu['utilization']:>5}%" if gpu["utilization"] >= 0 else " N/A " - m = f"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}" if gpu["memory_used"] >= 0 else " N/A / N/A " - t = f"{gpu['temperature']}C" if gpu["temperature"] >= 0 else " N/A " - p = f"{gpu['power_draw']:>3}/{gpu['power_limit']:<3}" if gpu["power_draw"] >= 0 else " N/A " - - LOGGER.info(f"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}") - - LOGGER.info(f"{'-' * len(hdr)}\n") - - def select_idle_gpu( - self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0 - ) -> List[int]: - """ - Select the most idle GPUs based on utilization and free memory. - - Args: - count (int): The number of idle GPUs to select. - min_memory_fraction (float): Minimum free memory required as a fraction of total memory. 
- min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0. - - Returns: - (List[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first). - - Notes: - Returns fewer than 'count' if not enough qualify or exist. - Returns basic CUDA indices if NVML fails. Empty list if no GPUs found. - """ - assert min_memory_fraction <= 1.0, f"min_memory_fraction must be <= 1.0, got {min_memory_fraction}" - assert min_util_fraction <= 1.0, f"min_util_fraction must be <= 1.0, got {min_util_fraction}" - LOGGER.info( - f"Searching for {count} idle GPUs with free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%..." - ) - - if count <= 0: - return [] - - self.refresh_stats() - if not self.gpu_stats: - LOGGER.warning("NVML stats unavailable.") - return [] - - # Filter and sort eligible GPUs - eligible_gpus = [ - gpu - for gpu in self.gpu_stats - if gpu.get("memory_free", 0) / gpu.get("memory_total", 1) >= min_memory_fraction - and (100 - gpu.get("utilization", 100)) >= min_util_fraction * 100 - ] - eligible_gpus.sort(key=lambda x: (x.get("utilization", 101), -x.get("memory_free", 0))) - - # Select top 'count' indices - selected = [gpu["index"] for gpu in eligible_gpus[:count]] - - if selected: - LOGGER.info(f"Selected idle CUDA devices {selected}") - else: - LOGGER.warning( - f"No GPUs met criteria (Free Mem >= {min_memory_fraction * 100:.1f}% and Free Util >= {min_util_fraction * 100:.1f}%)." 
- ) - - return selected - - -if __name__ == "__main__": - required_free_mem_fraction = 0.2 # Require 20% free VRAM - required_free_util_fraction = 0.2 # Require 20% free utilization - num_gpus_to_select = 1 - - gpu_info = GPUInfo() - gpu_info.print_status() - - selected = gpu_info.select_idle_gpu( - count=num_gpus_to_select, - min_memory_fraction=required_free_mem_fraction, - min_util_fraction=required_free_util_fraction, - ) - if selected: - print(f"\n==> Using selected GPU indices: {selected}") - devices = [f"cuda:{idx}" for idx in selected] - print(f" Target devices: {devices}") diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/benchmarks.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/benchmarks.py deleted file mode 100644 index df92cd1..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/benchmarks.py +++ /dev/null @@ -1,720 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -""" -Benchmark a YOLO model formats for speed and accuracy. 
- -Usage: - from ultralytics.utils.benchmarks import ProfileModels, benchmark - ProfileModels(['yolo11n.yaml', 'yolov8s.yaml']).run() - benchmark(model='yolo11n.pt', imgsz=160) - -Format | `format=argument` | Model ---- | --- | --- -PyTorch | - | yolo11n.pt -TorchScript | `torchscript` | yolo11n.torchscript -ONNX | `onnx` | yolo11n.onnx -OpenVINO | `openvino` | yolo11n_openvino_model/ -TensorRT | `engine` | yolo11n.engine -CoreML | `coreml` | yolo11n.mlpackage -TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/ -TensorFlow GraphDef | `pb` | yolo11n.pb -TensorFlow Lite | `tflite` | yolo11n.tflite -TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite -TensorFlow.js | `tfjs` | yolo11n_web_model/ -PaddlePaddle | `paddle` | yolo11n_paddle_model/ -MNN | `mnn` | yolo11n.mnn -NCNN | `ncnn` | yolo11n_ncnn_model/ -IMX | `imx` | yolo11n_imx_model/ -RKNN | `rknn` | yolo11n_rknn_model/ -""" - -import glob -import os -import platform -import re -import shutil -import time -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch.cuda - -from ultralytics import YOLO, YOLOWorld -from ultralytics.cfg import TASK2DATA, TASK2METRIC -from ultralytics.engine.exporter import export_formats -from ultralytics.utils import ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML -from ultralytics.utils.checks import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip -from ultralytics.utils.downloads import safe_download -from ultralytics.utils.files import file_size -from ultralytics.utils.torch_utils import get_cpu_info, select_device - - -def benchmark( - model=WEIGHTS_DIR / "yolo11n.pt", - data=None, - imgsz=160, - half=False, - int8=False, - device="cpu", - verbose=False, - eps=1e-3, - format="", - **kwargs, -): - """ - Benchmark a YOLO model across different formats for speed and accuracy. - - Args: - model (str | Path): Path to the model file or directory. 
- data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed. - imgsz (int): Image size for the benchmark. - half (bool): Use half-precision for the model if True. - int8 (bool): Use int8-precision for the model if True. - device (str): Device to run the benchmark on, either 'cpu' or 'cuda'. - verbose (bool | float): If True or a float, assert benchmarks pass with given metric. - eps (float): Epsilon value for divide by zero prevention. - format (str): Export format for benchmarking. If not supplied all formats are benchmarked. - **kwargs (Any): Additional keyword arguments for exporter. - - Returns: - (pandas.DataFrame): A pandas DataFrame with benchmark results for each format, including file size, metric, - and inference time. - - Examples: - Benchmark a YOLO model with default settings: - >>> from ultralytics.utils.benchmarks import benchmark - >>> benchmark(model="yolo11n.pt", imgsz=640) - """ - imgsz = check_imgsz(imgsz) - assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz." - - import pandas as pd # scope for faster 'import ultralytics' - - pd.options.display.max_columns = 10 - pd.options.display.width = 120 - device = select_device(device, verbose=False) - if isinstance(model, (str, Path)): - model = YOLO(model) - is_end2end = getattr(model.model.model[-1], "end2end", False) - data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect - key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect - - y = [] - t0 = time.time() - - format_arg = format.lower() - if format_arg: - formats = frozenset(export_formats()["Argument"]) - assert format in formats, f"Expected format to be one of {formats}, but got '{format_arg}'." 
- for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()): - emoji, filename = "❌", None # export defaults - try: - if format_arg and format_arg != format: - continue - - # Checks - if format == "pb": - assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task" - elif format == "edgetpu": - assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux" - elif format in {"coreml", "tfjs"}: - assert MACOS or (LINUX and not ARM64), ( - "CoreML and TF.js export only supported on macOS and non-aarch64 Linux" - ) - if format == "coreml": - assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13" - if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: - assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet" - # assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet" - if format == "paddle": - assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet" - assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024" - assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet" - assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet" - if format == "mnn": - assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet" - if format == "ncnn": - assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet" - if format == "imx": - assert not is_end2end - assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported" - assert model.task == "detect", "IMX only supported for detection task" - assert "C2f" in model.__str__(), "IMX only supported for YOLOv8" # TODO: enable for YOLO11 - if format == "rknn": - assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet" - assert not is_end2end, "End-to-end models not supported by 
RKNN yet" - assert LINUX, "RKNN only supported on Linux" - assert not is_rockchip(), "RKNN Inference only supported on Rockchip devices" - if "cpu" in device.type: - assert cpu, "inference not supported on CPU" - if "cuda" in device.type: - assert gpu, "inference not supported on GPU" - - # Export - if format == "-": - filename = model.pt_path or model.ckpt_path or model.model_name - exported_model = model # PyTorch format - else: - filename = model.export( - imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs - ) - exported_model = YOLO(filename, task=model.task) - assert suffix in str(filename), "export failed" - emoji = "❎" # indicates export succeeded - - # Predict - assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported" - assert format not in {"edgetpu", "tfjs"}, "inference not supported" - assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13" - if format == "ncnn": - assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet" - exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False) - - # Validate - results = exported_model.val( - data=data, - batch=1, - imgsz=imgsz, - plots=False, - device=device, - half=half, - int8=int8, - verbose=False, - conf=0.001, # all the pre-set benchmark mAP values are based on conf=0.001 - ) - metric, speed = results.results_dict[key], results.speed["inference"] - fps = round(1000 / (speed + eps), 2) # frames per second - y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps]) - except Exception as e: - if verbose: - assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}" - LOGGER.error(f"Benchmark failure for {name}: {e}") - y.append([name, emoji, round(file_size(filename), 1), None, None, None]) # mAP, t_inference - - # Print results - check_yolo(device=device) # print 
system info - df = pd.DataFrame(y, columns=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"]) - - name = model.model_name - dt = time.time() - t0 - legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed" - s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df.fillna('-')}\n" - LOGGER.info(s) - with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f: - f.write(s) - - if verbose and isinstance(verbose, float): - metrics = df[key].array # values to compare to floor - floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n - assert all(x > floor for x in metrics if pd.notna(x)), f"Benchmark failure: metric(s) < floor {floor}" - - return df - - -class RF100Benchmark: - """ - Benchmark YOLO model performance across various formats for speed and accuracy. - - This class provides functionality to benchmark YOLO models on the RF100 dataset collection. - - Attributes: - ds_names (List[str]): Names of datasets used for benchmarking. - ds_cfg_list (List[Path]): List of paths to dataset configuration files. - rf (Roboflow): Roboflow instance for accessing datasets. - val_metrics (List[str]): Metrics used for validation. - - Methods: - set_key: Set Roboflow API key for accessing datasets. - parse_dataset: Parse dataset links and download datasets. - fix_yaml: Fix train and validation paths in YAML files. - evaluate: Evaluate model performance on validation results. - """ - - def __init__(self): - """Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats.""" - self.ds_names = [] - self.ds_cfg_list = [] - self.rf = None - self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"] - - def set_key(self, api_key: str): - """ - Set Roboflow API key for processing. - - Args: - api_key (str): The API key. 
- - Examples: - Set the Roboflow API key for accessing datasets: - >>> benchmark = RF100Benchmark() - >>> benchmark.set_key("your_roboflow_api_key") - """ - check_requirements("roboflow") - from roboflow import Roboflow - - self.rf = Roboflow(api_key=api_key) - - def parse_dataset(self, ds_link_txt: str = "datasets_links.txt"): - """ - Parse dataset links and download datasets. - - Args: - ds_link_txt (str): Path to the file containing dataset links. - - Returns: - ds_names (List[str]): List of dataset names. - ds_cfg_list (List[Path]): List of paths to dataset configuration files. - - Examples: - >>> benchmark = RF100Benchmark() - >>> benchmark.set_key("api_key") - >>> benchmark.parse_dataset("datasets_links.txt") - """ - (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100") - os.chdir("rf-100") - os.mkdir("ultralytics-benchmarks") - safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt") - - with open(ds_link_txt, encoding="utf-8") as file: - for line in file: - try: - _, url, workspace, project, version = re.split("/+", line.strip()) - self.ds_names.append(project) - proj_version = f"{project}-{version}" - if not Path(proj_version).exists(): - self.rf.workspace(workspace).project(project).version(version).download("yolov8") - else: - LOGGER.info("Dataset already downloaded.") - self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml") - except Exception: - continue - - return self.ds_names, self.ds_cfg_list - - @staticmethod - def fix_yaml(path: Path): - """Fix the train and validation paths in a given YAML file.""" - yaml_data = YAML.load(path) - yaml_data["train"] = "train/images" - yaml_data["val"] = "valid/images" - YAML.dump(yaml_data, path) - - def evaluate(self, yaml_path: str, val_log_file: str, eval_log_file: str, list_ind: int): - """ - Evaluate model performance on validation results. - - Args: - yaml_path (str): Path to the YAML configuration file. 
- val_log_file (str): Path to the validation log file. - eval_log_file (str): Path to the evaluation log file. - list_ind (int): Index of the current dataset in the list. - - Returns: - (float): The mean average precision (mAP) value for the evaluated model. - - Examples: - Evaluate a model on a specific dataset - >>> benchmark = RF100Benchmark() - >>> benchmark.evaluate("path/to/data.yaml", "path/to/val_log.txt", "path/to/eval_log.txt", 0) - """ - skip_symbols = ["🚀", "⚠️", "💡", "❌"] - class_names = YAML.load(yaml_path)["names"] - with open(val_log_file, encoding="utf-8") as f: - lines = f.readlines() - eval_lines = [] - for line in lines: - if any(symbol in line for symbol in skip_symbols): - continue - entries = line.split(" ") - entries = list(filter(lambda val: val != "", entries)) - entries = [e.strip("\n") for e in entries] - eval_lines.extend( - { - "class": entries[0], - "images": entries[1], - "targets": entries[2], - "precision": entries[3], - "recall": entries[4], - "map50": entries[5], - "map95": entries[6], - } - for e in entries - if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries) - ) - map_val = 0.0 - if len(eval_lines) > 1: - LOGGER.info("Multiple dicts found") - for lst in eval_lines: - if lst["class"] == "all": - map_val = lst["map50"] - else: - LOGGER.info("Single dict found") - map_val = [res["map50"] for res in eval_lines][0] - - with open(eval_log_file, "a", encoding="utf-8") as f: - f.write(f"{self.ds_names[list_ind]}: {map_val}\n") - - return float(map_val) - - -class ProfileModels: - """ - ProfileModels class for profiling different models on ONNX and TensorRT. - - This class profiles the performance of different models, returning results such as model speed and FLOPs. - - Attributes: - paths (List[str]): Paths of the models to profile. - num_timed_runs (int): Number of timed runs for the profiling. - num_warmup_runs (int): Number of warmup runs before profiling. 
- min_time (float): Minimum number of seconds to profile for. - imgsz (int): Image size used in the models. - half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling. - trt (bool): Flag to indicate whether to profile using TensorRT. - device (torch.device): Device used for profiling. - - Methods: - run: Profile YOLO models for speed and accuracy across various formats. - get_files: Get all relevant model files. - get_onnx_model_info: Extract metadata from an ONNX model. - iterative_sigma_clipping: Apply sigma clipping to remove outliers. - profile_tensorrt_model: Profile a TensorRT model. - profile_onnx_model: Profile an ONNX model. - generate_table_row: Generate a table row with model metrics. - generate_results_dict: Generate a dictionary of profiling results. - print_table: Print a formatted table of results. - - Examples: - Profile models and print results - >>> from ultralytics.utils.benchmarks import ProfileModels - >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640) - >>> profiler.run() - """ - - def __init__( - self, - paths: List[str], - num_timed_runs: int = 100, - num_warmup_runs: int = 10, - min_time: float = 60, - imgsz: int = 640, - half: bool = True, - trt: bool = True, - device: Optional[Union[torch.device, str]] = None, - ): - """ - Initialize the ProfileModels class for profiling models. - - Args: - paths (List[str]): List of paths of the models to be profiled. - num_timed_runs (int): Number of timed runs for the profiling. - num_warmup_runs (int): Number of warmup runs before the actual profiling starts. - min_time (float): Minimum time in seconds for profiling a model. - imgsz (int): Size of the image used during profiling. - half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling. - trt (bool): Flag to indicate whether to profile using TensorRT. - device (torch.device | str | None): Device used for profiling. If None, it is determined automatically. 
- - Notes: - FP16 'half' argument option removed for ONNX as slower on CPU than FP32. - - Examples: - Initialize and profile models - >>> from ultralytics.utils.benchmarks import ProfileModels - >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640) - >>> profiler.run() - """ - self.paths = paths - self.num_timed_runs = num_timed_runs - self.num_warmup_runs = num_warmup_runs - self.min_time = min_time - self.imgsz = imgsz - self.half = half - self.trt = trt # run TensorRT profiling - self.device = device if isinstance(device, torch.device) else select_device(device) - - def run(self): - """ - Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT. - - Returns: - (List[dict]): List of dictionaries containing profiling results for each model. - - Examples: - Profile models and print results - >>> from ultralytics.utils.benchmarks import ProfileModels - >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"]) - >>> results = profiler.run() - """ - files = self.get_files() - - if not files: - LOGGER.warning("No matching *.pt or *.onnx files found.") - return [] - - table_rows = [] - output = [] - for file in files: - engine_file = file.with_suffix(".engine") - if file.suffix in {".pt", ".yaml", ".yml"}: - model = YOLO(str(file)) - model.fuse() # to report correct params and GFLOPs in model.info() - model_info = model.info() - if self.trt and self.device.type != "cpu" and not engine_file.is_file(): - engine_file = model.export( - format="engine", - half=self.half, - imgsz=self.imgsz, - device=self.device, - verbose=False, - ) - onnx_file = model.export( - format="onnx", - imgsz=self.imgsz, - device=self.device, - verbose=False, - ) - elif file.suffix == ".onnx": - model_info = self.get_onnx_model_info(file) - onnx_file = file - else: - continue - - t_engine = self.profile_tensorrt_model(str(engine_file)) - t_onnx = self.profile_onnx_model(str(onnx_file)) - 
table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info)) - output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info)) - - self.print_table(table_rows) - return output - - def get_files(self): - """ - Return a list of paths for all relevant model files given by the user. - - Returns: - (List[Path]): List of Path objects for the model files. - """ - files = [] - for path in self.paths: - path = Path(path) - if path.is_dir(): - extensions = ["*.pt", "*.onnx", "*.yaml"] - files.extend([file for ext in extensions for file in glob.glob(str(path / ext))]) - elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing - files.append(str(path)) - else: - files.extend(glob.glob(str(path))) - - LOGGER.info(f"Profiling: {sorted(files)}") - return [Path(file) for file in sorted(files)] - - @staticmethod - def get_onnx_model_info(onnx_file: str): - """Extract metadata from an ONNX model file including parameters, GFLOPs, and input shape.""" - return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops) - - @staticmethod - def iterative_sigma_clipping(data: np.ndarray, sigma: float = 2, max_iters: int = 3): - """ - Apply iterative sigma clipping to data to remove outliers. - - Args: - data (np.ndarray): Input data array. - sigma (float): Number of standard deviations to use for clipping. - max_iters (int): Maximum number of iterations for the clipping process. - - Returns: - (np.ndarray): Clipped data array with outliers removed. - """ - data = np.array(data) - for _ in range(max_iters): - mean, std = np.mean(data), np.std(data) - clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)] - if len(clipped_data) == len(data): - break - data = clipped_data - return data - - def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3): - """ - Profile YOLO model performance with TensorRT, measuring average run time and standard deviation. 
- - Args: - engine_file (str): Path to the TensorRT engine file. - eps (float): Small epsilon value to prevent division by zero. - - Returns: - mean_time (float): Mean inference time in milliseconds. - std_time (float): Standard deviation of inference time in milliseconds. - """ - if not self.trt or not Path(engine_file).is_file(): - return 0.0, 0.0 - - # Model and input - model = YOLO(engine_file) - input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8) # use uint8 for Classify - - # Warmup runs - elapsed = 0.0 - for _ in range(3): - start_time = time.time() - for _ in range(self.num_warmup_runs): - model(input_data, imgsz=self.imgsz, verbose=False) - elapsed = time.time() - start_time - - # Compute number of runs as higher of min_time or num_timed_runs - num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50) - - # Timed runs - run_times = [] - for _ in TQDM(range(num_runs), desc=engine_file): - results = model(input_data, imgsz=self.imgsz, verbose=False) - run_times.append(results[0].speed["inference"]) # Convert to milliseconds - - run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping - return np.mean(run_times), np.std(run_times) - - def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3): - """ - Profile an ONNX model, measuring average inference time and standard deviation across multiple runs. - - Args: - onnx_file (str): Path to the ONNX model file. - eps (float): Small epsilon value to prevent division by zero. - - Returns: - mean_time (float): Mean inference time in milliseconds. - std_time (float): Standard deviation of inference time in milliseconds. 
- """ - check_requirements("onnxruntime") - import onnxruntime as ort - - # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider' - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - sess_options.intra_op_num_threads = 8 # Limit the number of threads - sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"]) - - input_tensor = sess.get_inputs()[0] - input_type = input_tensor.type - dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape - input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape - - # Mapping ONNX datatype to numpy datatype - if "float16" in input_type: - input_dtype = np.float16 - elif "float" in input_type: - input_dtype = np.float32 - elif "double" in input_type: - input_dtype = np.float64 - elif "int64" in input_type: - input_dtype = np.int64 - elif "int32" in input_type: - input_dtype = np.int32 - else: - raise ValueError(f"Unsupported ONNX datatype {input_type}") - - input_data = np.random.rand(*input_shape).astype(input_dtype) - input_name = input_tensor.name - output_name = sess.get_outputs()[0].name - - # Warmup runs - elapsed = 0.0 - for _ in range(3): - start_time = time.time() - for _ in range(self.num_warmup_runs): - sess.run([output_name], {input_name: input_data}) - elapsed = time.time() - start_time - - # Compute number of runs as higher of min_time or num_timed_runs - num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs) - - # Timed runs - run_times = [] - for _ in TQDM(range(num_runs), desc=onnx_file): - start_time = time.time() - sess.run([output_name], {input_name: input_data}) - run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds - - run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping - return 
np.mean(run_times), np.std(run_times) - - def generate_table_row( - self, - model_name: str, - t_onnx: Tuple[float, float], - t_engine: Tuple[float, float], - model_info: Tuple[float, float, float, float], - ): - """ - Generate a table row string with model performance metrics. - - Args: - model_name (str): Name of the model. - t_onnx (tuple): ONNX model inference time statistics (mean, std). - t_engine (tuple): TensorRT engine inference time statistics (mean, std). - model_info (tuple): Model information (layers, params, gradients, flops). - - Returns: - (str): Formatted table row string with model metrics. - """ - layers, params, gradients, flops = model_info - return ( - f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±" - f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |" - ) - - @staticmethod - def generate_results_dict( - model_name: str, - t_onnx: Tuple[float, float], - t_engine: Tuple[float, float], - model_info: Tuple[float, float, float, float], - ): - """ - Generate a dictionary of profiling results. - - Args: - model_name (str): Name of the model. - t_onnx (tuple): ONNX model inference time statistics (mean, std). - t_engine (tuple): TensorRT engine inference time statistics (mean, std). - model_info (tuple): Model information (layers, params, gradients, flops). - - Returns: - (dict): Dictionary containing profiling results. - """ - layers, params, gradients, flops = model_info - return { - "model/name": model_name, - "model/parameters": params, - "model/GFLOPs": round(flops, 3), - "model/speed_ONNX(ms)": round(t_onnx[0], 3), - "model/speed_TensorRT(ms)": round(t_engine[0], 3), - } - - @staticmethod - def print_table(table_rows: List[str]): - """ - Print a formatted table of model profiling results. - - Args: - table_rows (List[str]): List of formatted table row strings. - """ - gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU" - headers = [ - "Model", - "size
(pixels)", - "mAPval
50-95", - f"Speed
CPU ({get_cpu_info()}) ONNX
(ms)", - f"Speed
{gpu} TensorRT
(ms)", - "params
(M)", - "FLOPs
(B)", - ] - header = "|" + "|".join(f" {h} " for h in headers) + "|" - separator = "|" + "|".join("-" * (len(h) + 2) for h in headers) + "|" - - LOGGER.info(f"\n\n{header}") - LOGGER.info(separator) - for row in table_rows: - LOGGER.info(row) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/__init__.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/__init__.py deleted file mode 100644 index 920cc4f..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from .base import add_integration_callbacks, default_callbacks, get_default_callbacks - -__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks" diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/base.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/base.py deleted file mode 100644 index 23be858..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/base.py +++ /dev/null @@ -1,234 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -"""Base callbacks for Ultralytics training, validation, prediction, and export processes.""" - -from collections import defaultdict -from copy import deepcopy - -# Trainer callbacks ---------------------------------------------------------------------------------------------------- - - -def on_pretrain_routine_start(trainer): - """Called before the pretraining routine starts.""" - pass - - -def on_pretrain_routine_end(trainer): - """Called after the pretraining routine ends.""" - pass - - -def on_train_start(trainer): - """Called when the training starts.""" - pass - - -def on_train_epoch_start(trainer): - """Called at the start of each training epoch.""" - pass - - -def on_train_batch_start(trainer): - """Called at the start of each training batch.""" - pass - - -def 
optimizer_step(trainer): - """Called when the optimizer takes a step.""" - pass - - -def on_before_zero_grad(trainer): - """Called before the gradients are set to zero.""" - pass - - -def on_train_batch_end(trainer): - """Called at the end of each training batch.""" - pass - - -def on_train_epoch_end(trainer): - """Called at the end of each training epoch.""" - pass - - -def on_fit_epoch_end(trainer): - """Called at the end of each fit epoch (train + val).""" - pass - - -def on_model_save(trainer): - """Called when the model is saved.""" - pass - - -def on_train_end(trainer): - """Called when the training ends.""" - pass - - -def on_params_update(trainer): - """Called when the model parameters are updated.""" - pass - - -def teardown(trainer): - """Called during the teardown of the training process.""" - pass - - -# Validator callbacks -------------------------------------------------------------------------------------------------- - - -def on_val_start(validator): - """Called when the validation starts.""" - pass - - -def on_val_batch_start(validator): - """Called at the start of each validation batch.""" - pass - - -def on_val_batch_end(validator): - """Called at the end of each validation batch.""" - pass - - -def on_val_end(validator): - """Called when the validation ends.""" - pass - - -# Predictor callbacks -------------------------------------------------------------------------------------------------- - - -def on_predict_start(predictor): - """Called when the prediction starts.""" - pass - - -def on_predict_batch_start(predictor): - """Called at the start of each prediction batch.""" - pass - - -def on_predict_batch_end(predictor): - """Called at the end of each prediction batch.""" - pass - - -def on_predict_postprocess_end(predictor): - """Called after the post-processing of the prediction ends.""" - pass - - -def on_predict_end(predictor): - """Called when the prediction ends.""" - pass - - -# Exporter callbacks 
--------------------------------------------------------------------------------------------------- - - -def on_export_start(exporter): - """Called when the model export starts.""" - pass - - -def on_export_end(exporter): - """Called when the model export ends.""" - pass - - -default_callbacks = { - # Run in trainer - "on_pretrain_routine_start": [on_pretrain_routine_start], - "on_pretrain_routine_end": [on_pretrain_routine_end], - "on_train_start": [on_train_start], - "on_train_epoch_start": [on_train_epoch_start], - "on_train_batch_start": [on_train_batch_start], - "optimizer_step": [optimizer_step], - "on_before_zero_grad": [on_before_zero_grad], - "on_train_batch_end": [on_train_batch_end], - "on_train_epoch_end": [on_train_epoch_end], - "on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val - "on_model_save": [on_model_save], - "on_train_end": [on_train_end], - "on_params_update": [on_params_update], - "teardown": [teardown], - # Run in validator - "on_val_start": [on_val_start], - "on_val_batch_start": [on_val_batch_start], - "on_val_batch_end": [on_val_batch_end], - "on_val_end": [on_val_end], - # Run in predictor - "on_predict_start": [on_predict_start], - "on_predict_batch_start": [on_predict_batch_start], - "on_predict_postprocess_end": [on_predict_postprocess_end], - "on_predict_batch_end": [on_predict_batch_end], - "on_predict_end": [on_predict_end], - # Run in exporter - "on_export_start": [on_export_start], - "on_export_end": [on_export_end], -} - - -def get_default_callbacks(): - """ - Get the default callbacks for Ultralytics training, validation, prediction, and export processes. - - Returns: - (dict): Dictionary of default callbacks for various training events. Each key represents an event during the - training process, and the corresponding value is a list of callback functions executed when that event - occurs. 
- - Examples: - >>> callbacks = get_default_callbacks() - >>> print(list(callbacks.keys())) # show all available callback events - ['on_pretrain_routine_start', 'on_pretrain_routine_end', ...] - """ - return defaultdict(list, deepcopy(default_callbacks)) - - -def add_integration_callbacks(instance): - """ - Add integration callbacks to the instance's callbacks dictionary. - - This function loads and adds various integration callbacks to the provided instance. The specific callbacks added - depend on the type of instance provided. All instances receive HUB callbacks, while Trainer instances also receive - additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard, - and Weights & Biases. - - Args: - instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added. - The type of instance determines which callbacks are loaded. - - Examples: - >>> from ultralytics.engine.trainer import BaseTrainer - >>> trainer = BaseTrainer() - >>> add_integration_callbacks(trainer) - """ - # Load HUB callbacks - from .hub import callbacks as hub_cb - - callbacks_list = [hub_cb] - - # Load training callbacks - if "Trainer" in instance.__class__.__name__: - from .clearml import callbacks as clear_cb - from .comet import callbacks as comet_cb - from .dvc import callbacks as dvc_cb - from .mlflow import callbacks as mlflow_cb - from .neptune import callbacks as neptune_cb - from .raytune import callbacks as tune_cb - from .tensorboard import callbacks as tb_cb - from .wb import callbacks as wb_cb - - callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb]) - - # Add the callbacks to the callbacks dictionary - for callbacks in callbacks_list: - for k, v in callbacks.items(): - if v not in instance.callbacks[k]: - instance.callbacks[k].append(v) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/clearml.py 
b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/clearml.py deleted file mode 100644 index 3f22329..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/clearml.py +++ /dev/null @@ -1,154 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING - -try: - assert not TESTS_RUNNING # do not log pytest - assert SETTINGS["clearml"] is True # verify integration is enabled - import clearml - from clearml import Task - - assert hasattr(clearml, "__version__") # verify package is not directory - -except (ImportError, AssertionError): - clearml = None - - -def _log_debug_samples(files, title: str = "Debug Samples") -> None: - """ - Log files (images) as debug samples in the ClearML task. - - Args: - files (List[Path]): A list of file paths in PosixPath format. - title (str): A title that groups together images with the same values. - """ - import re - - if task := Task.current_task(): - for f in files: - if f.exists(): - it = re.search(r"_batch(\d+)", f.name) - iteration = int(it.groups()[0]) if it else 0 - task.get_logger().report_image( - title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration - ) - - -def _log_plot(title: str, plot_path: str) -> None: - """ - Log an image as a plot in the plot section of ClearML. - - Args: - title (str): The title of the plot. - plot_path (str): The path to the saved image file. 
- """ - import matplotlib.image as mpimg - import matplotlib.pyplot as plt - - img = mpimg.imread(plot_path) - fig = plt.figure() - ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks - ax.imshow(img) - - Task.current_task().get_logger().report_matplotlib_figure( - title=title, series="", figure=fig, report_interactive=False - ) - - -def on_pretrain_routine_start(trainer) -> None: - """Initialize and connect ClearML task at the start of pretraining routine.""" - try: - if task := Task.current_task(): - # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled! - # We are logging these plots and model files manually in the integration - from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO - from clearml.binding.matplotlib_bind import PatchedMatplotlib - - PatchPyTorchModelIO.update_current_task(None) - PatchedMatplotlib.update_current_task(None) - else: - task = Task.init( - project_name=trainer.args.project or "Ultralytics", - task_name=trainer.args.name, - tags=["Ultralytics"], - output_uri=True, - reuse_last_task_id=False, - auto_connect_frameworks={"pytorch": False, "matplotlib": False}, - ) - LOGGER.warning( - "ClearML Initialized a new task. If you want to run remotely, " - "please add clearml-init and connect your arguments before initializing YOLO." - ) - task.connect(vars(trainer.args), name="General") - except Exception as e: - LOGGER.warning(f"ClearML installed but not initialized correctly, not logging this run. 
{e}") - - -def on_train_epoch_end(trainer) -> None: - """Log debug samples for the first epoch and report current training progress.""" - if task := Task.current_task(): - # Log debug samples for first epoch only - if trainer.epoch == 1: - _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic") - # Report the current training progress - for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items(): - task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch) - for k, v in trainer.lr.items(): - task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch) - - -def on_fit_epoch_end(trainer) -> None: - """Report model information and metrics to logger at the end of an epoch.""" - if task := Task.current_task(): - # Report epoch time and validation metrics - task.get_logger().report_scalar( - title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch - ) - for k, v in trainer.metrics.items(): - title = k.split("/")[0] - task.get_logger().report_scalar(title, k, v, iteration=trainer.epoch) - if trainer.epoch == 0: - from ultralytics.utils.torch_utils import model_info_for_loggers - - for k, v in model_info_for_loggers(trainer).items(): - task.get_logger().report_single_value(k, v) - - -def on_val_end(validator) -> None: - """Log validation results including labels and predictions.""" - if Task.current_task(): - # Log validation labels and predictions - _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation") - - -def on_train_end(trainer) -> None: - """Log final model and training results on training completion.""" - if task := Task.current_task(): - # Log final results, confusion matrix and PR plots - files = [ - "results.png", - "confusion_matrix.png", - "confusion_matrix_normalized.png", - *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), - ] - files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter existing files - for 
f in files: - _log_plot(title=f.stem, plot_path=f) - # Report final metrics - for k, v in trainer.validator.metrics.results_dict.items(): - task.get_logger().report_single_value(k, v) - # Log the final model - task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False) - - -callbacks = ( - { - "on_pretrain_routine_start": on_pretrain_routine_start, - "on_train_epoch_end": on_train_epoch_end, - "on_fit_epoch_end": on_fit_epoch_end, - "on_val_end": on_val_end, - "on_train_end": on_train_end, - } - if clearml - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/comet.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/comet.py deleted file mode 100644 index fa484cb..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/comet.py +++ /dev/null @@ -1,639 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from collections.abc import Callable -from types import SimpleNamespace -from typing import Any, List, Optional - -import cv2 -import numpy as np - -from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops -from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics - -try: - assert not TESTS_RUNNING # do not log pytest - assert SETTINGS["comet"] is True # verify integration is enabled - import comet_ml - - assert hasattr(comet_ml, "__version__") # verify package is not directory - - import os - from pathlib import Path - - # Ensures certain logging functions only run for supported tasks - COMET_SUPPORTED_TASKS = ["detect", "segment"] - - # Names of plots created by Ultralytics that are logged to Comet - CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized" - EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve" - LABEL_PLOT_NAMES = ["labels"] - SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask" - 
POSE_METRICS_PLOT_PREFIX = "Box", "Pose" - DETECTION_METRICS_PLOT_PREFIX = ["Box"] - RESULTS_TABLE_NAME = "results.csv" - ARGS_YAML_NAME = "args.yaml" - - _comet_image_prediction_count = 0 - -except (ImportError, AssertionError): - comet_ml = None - - -def _get_comet_mode() -> str: - """Return the Comet mode from environment variables, defaulting to 'online'.""" - comet_mode = os.getenv("COMET_MODE") - if comet_mode is not None: - LOGGER.warning( - "The COMET_MODE environment variable is deprecated. " - "Please use COMET_START_ONLINE to set the Comet experiment mode. " - "To start an offline Comet experiment, use 'export COMET_START_ONLINE=0'. " - "If COMET_START_ONLINE is not set or is set to '1', an online Comet experiment will be created." - ) - return comet_mode - - return "online" - - -def _get_comet_model_name() -> str: - """Return the Comet model name from environment variable or default to 'Ultralytics'.""" - return os.getenv("COMET_MODEL_NAME", "Ultralytics") - - -def _get_eval_batch_logging_interval() -> int: - """Get the evaluation batch logging interval from environment variable or use default value 1.""" - return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1)) - - -def _get_max_image_predictions_to_log() -> int: - """Get the maximum number of image predictions to log from environment variables.""" - return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100)) - - -def _scale_confidence_score(score: float) -> float: - """Scale the confidence score by a factor specified in environment variable.""" - scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0)) - return score * scale - - -def _should_log_confusion_matrix() -> bool: - """Determine if the confusion matrix should be logged based on environment variable settings.""" - return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true" - - -def _should_log_image_predictions() -> bool: - """Determine whether to log image predictions based on environment variable.""" - return 
os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true" - - -def _resume_or_create_experiment(args: SimpleNamespace) -> None: - """ - Resume CometML experiment or create a new experiment based on args. - - Ensures that the experiment object is only created in a single process during distributed training. - - Args: - args (SimpleNamespace): Training arguments containing project configuration and other parameters. - """ - if RANK not in {-1, 0}: - return - - # Set environment variable (if not set by the user) to configure the Comet experiment's online mode under the hood. - # IF COMET_START_ONLINE is set by the user it will override COMET_MODE value. - if os.getenv("COMET_START_ONLINE") is None: - comet_mode = _get_comet_mode() - os.environ["COMET_START_ONLINE"] = "1" if comet_mode != "offline" else "0" - - try: - _project_name = os.getenv("COMET_PROJECT_NAME", args.project) - experiment = comet_ml.start(project_name=_project_name) - experiment.log_parameters(vars(args)) - experiment.log_others( - { - "eval_batch_logging_interval": _get_eval_batch_logging_interval(), - "log_confusion_matrix_on_eval": _should_log_confusion_matrix(), - "log_image_predictions": _should_log_image_predictions(), - "max_image_predictions": _get_max_image_predictions_to_log(), - } - ) - experiment.log_other("Created from", "ultralytics") - - except Exception as e: - LOGGER.warning(f"Comet installed but not initialized correctly, not logging this run. {e}") - - -def _fetch_trainer_metadata(trainer) -> dict: - """ - Return metadata for YOLO training including epoch and asset saving status. - - Args: - trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config. - - Returns: - (dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag. 
- """ - curr_epoch = trainer.epoch + 1 - - train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size - curr_step = curr_epoch * train_num_steps_per_epoch - final_epoch = curr_epoch == trainer.epochs - - save = trainer.args.save - save_period = trainer.args.save_period - save_interval = curr_epoch % save_period == 0 - save_assets = save and save_period > 0 and save_interval and not final_epoch - - return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch) - - -def _scale_bounding_box_to_original_image_shape( - box, resized_image_shape, original_image_shape, ratio_pad -) -> List[float]: - """ - Scale bounding box from resized image coordinates to original image coordinates. - - YOLO resizes images during training and the label values are normalized based on this resized shape. - This function rescales the bounding box labels to the original image shape. - - Args: - box (torch.Tensor): Bounding box in normalized xywh format. - resized_image_shape (tuple): Shape of the resized image (height, width). - original_image_shape (tuple): Shape of the original image (height, width). - ratio_pad (tuple): Ratio and padding information for scaling. - - Returns: - (List[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment. 
- """ - resized_image_height, resized_image_width = resized_image_shape - - # Convert normalized xywh format predictions to xyxy in resized scale format - box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width) - # Scale box predictions from resized image scale back to original image scale - box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad) - # Convert bounding box format from xyxy to xywh for Comet logging - box = ops.xyxy2xywh(box) - # Adjust xy center to correspond top-left corner - box[:2] -= box[2:] / 2 - box = box.tolist() - - return box - - -def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> Optional[dict]: - """ - Format ground truth annotations for object detection. - - This function processes ground truth annotations from a batch of images for object detection tasks. It extracts - bounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for - visualization or evaluation. - - Args: - img_idx (int): Index of the image in the batch to process. - image_path (str | Path): Path to the image file. - batch (dict): Batch dictionary containing detection data with keys: - - 'batch_idx': Tensor of batch indices - - 'bboxes': Tensor of bounding boxes in normalized xywh format - - 'cls': Tensor of class labels - - 'ori_shape': Original image shapes - - 'resized_shape': Resized image shapes - - 'ratio_pad': Ratio and padding information - class_name_map (dict, optional): Mapping from class indices to class names. - - Returns: - (dict | None): Formatted ground truth annotations with the following structure: - - 'boxes': List of box coordinates [x, y, width, height] - - 'label': Label string with format "gt_{class_name}" - - 'score': Confidence score (always 1.0, scaled by _scale_confidence_score) - Returns None if no bounding boxes are found for the image. 
- """ - indices = batch["batch_idx"] == img_idx - bboxes = batch["bboxes"][indices] - if len(bboxes) == 0: - LOGGER.debug(f"Comet Image: {image_path} has no bounding boxes labels") - return None - - cls_labels = batch["cls"][indices].squeeze(1).tolist() - if class_name_map: - cls_labels = [str(class_name_map[label]) for label in cls_labels] - - original_image_shape = batch["ori_shape"][img_idx] - resized_image_shape = batch["resized_shape"][img_idx] - ratio_pad = batch["ratio_pad"][img_idx] - - data = [] - for box, label in zip(bboxes, cls_labels): - box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad) - data.append( - { - "boxes": [box], - "label": f"gt_{label}", - "score": _scale_confidence_score(1.0), - } - ) - - return {"name": "ground_truth", "data": data} - - -def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> Optional[dict]: - """ - Format YOLO predictions for object detection visualization. - - Args: - image_path (Path): Path to the image file. - metadata (dict): Prediction metadata containing bounding boxes and class information. - class_label_map (dict, optional): Mapping from class indices to class names. - class_map (dict, optional): Additional class mapping for label conversion. - - Returns: - (dict | None): Formatted prediction annotations or None if no predictions exist. - """ - stem = image_path.stem - image_id = int(stem) if stem.isnumeric() else stem - - predictions = metadata.get(image_id) - if not predictions: - LOGGER.debug(f"Comet Image: {image_path} has no bounding boxes predictions") - return None - - # apply the mapping that was used to map the predicted classes when the JSON was created - if class_label_map and class_map: - class_label_map = {class_map[k]: v for k, v in class_label_map.items()} - try: - # import pycotools utilities to decompress annotations for various tasks, e.g. 
segmentation - from faster_coco_eval.core.mask import decode # noqa - except ImportError: - decode = None - - data = [] - for prediction in predictions: - boxes = prediction["bbox"] - score = _scale_confidence_score(prediction["score"]) - cls_label = prediction["category_id"] - if class_label_map: - cls_label = str(class_label_map[cls_label]) - - annotation_data = {"boxes": [boxes], "label": cls_label, "score": score} - - if decode is not None: - # do segmentation processing only if we are able to decode it - segments = prediction.get("segmentation", None) - if segments is not None: - segments = _extract_segmentation_annotation(segments, decode) - if segments is not None: - annotation_data["points"] = segments - - data.append(annotation_data) - - return {"name": "prediction", "data": data} - - -def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> Optional[List[List[Any]]]: - """ - Extract segmentation annotation from compressed segmentations as list of polygons. - - Args: - segmentation_raw (str): Raw segmentation data in compressed format. - decode (Callable): Function to decode the compressed segmentation data. - - Returns: - (List[List[Any]] | None): List of polygon points or None if extraction fails. - """ - try: - mask = decode(segmentation_raw) - contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) - annotations = [np.array(polygon).squeeze() for polygon in contours if len(polygon) >= 3] - return [annotation.ravel().tolist() for annotation in annotations] - except Exception as e: - LOGGER.warning(f"Comet Failed to extract segmentation annotation: {e}") - return None - - -def _fetch_annotations( - img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map -) -> Optional[List]: - """ - Join the ground truth and prediction annotations if they exist. - - Args: - img_idx (int): Index of the image in the batch. - image_path (Path): Path to the image file. 
- batch (dict): Batch data containing ground truth annotations. - prediction_metadata_map (dict): Map of prediction metadata by image ID. - class_label_map (dict): Mapping from class indices to class names. - class_map (dict): Additional class mapping for label conversion. - - Returns: - (List | None): List of annotation dictionaries or None if no annotations exist. - """ - ground_truth_annotations = _format_ground_truth_annotations_for_detection( - img_idx, image_path, batch, class_label_map - ) - prediction_annotations = _format_prediction_annotations( - image_path, prediction_metadata_map, class_label_map, class_map - ) - - annotations = [ - annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None - ] - return [annotations] if annotations else None - - -def _create_prediction_metadata_map(model_predictions) -> dict: - """Create metadata map for model predictions by grouping them based on image ID.""" - pred_metadata_map = {} - for prediction in model_predictions: - pred_metadata_map.setdefault(prediction["image_id"], []) - pred_metadata_map[prediction["image_id"]].append(prediction) - - return pred_metadata_map - - -def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None: - """Log the confusion matrix to Comet experiment.""" - conf_mat = trainer.validator.confusion_matrix.matrix - names = list(trainer.data["names"].values()) + ["background"] - experiment.log_confusion_matrix( - matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step - ) - - -def _log_images(experiment, image_paths, curr_step: Optional[int], annotations=None) -> None: - """ - Log images to the experiment with optional annotations. - - This function logs images to a Comet ML experiment, optionally including annotation data for visualization - such as bounding boxes or segmentation masks. - - Args: - experiment (comet_ml.CometExperiment): The Comet ML experiment to log images to. 
- image_paths (List[Path]): List of paths to images that will be logged. - curr_step (int): Current training step/iteration for tracking in the experiment timeline. - annotations (List[List[dict]], optional): Nested list of annotation dictionaries for each image. Each - annotation contains visualization data like bounding boxes, labels, and confidence scores. - """ - if annotations: - for image_path, annotation in zip(image_paths, annotations): - experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation) - - else: - for image_path in image_paths: - experiment.log_image(image_path, name=image_path.stem, step=curr_step) - - -def _log_image_predictions(experiment, validator, curr_step) -> None: - """ - Log predicted boxes for a single image during training. - - This function logs image predictions to a Comet ML experiment during model validation. It processes - validation data and formats both ground truth and prediction annotations for visualization in the Comet - dashboard. The function respects configured limits on the number of images to log. - - Args: - experiment (comet_ml.CometExperiment): The Comet ML experiment to log to. - validator (BaseValidator): The validator instance containing validation data and predictions. - curr_step (int): The current training step for logging timeline. - - Notes: - This function uses global state to track the number of logged predictions across calls. - It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS. - The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable. 
- """ - global _comet_image_prediction_count - - task = validator.args.task - if task not in COMET_SUPPORTED_TASKS: - return - - jdict = validator.jdict - if not jdict: - return - - predictions_metadata_map = _create_prediction_metadata_map(jdict) - dataloader = validator.dataloader - class_label_map = validator.names - class_map = getattr(validator, "class_map", None) - - batch_logging_interval = _get_eval_batch_logging_interval() - max_image_predictions = _get_max_image_predictions_to_log() - - for batch_idx, batch in enumerate(dataloader): - if (batch_idx + 1) % batch_logging_interval != 0: - continue - - image_paths = batch["im_file"] - for img_idx, image_path in enumerate(image_paths): - if _comet_image_prediction_count >= max_image_predictions: - return - - image_path = Path(image_path) - annotations = _fetch_annotations( - img_idx, - image_path, - batch, - predictions_metadata_map, - class_label_map, - class_map=class_map, - ) - _log_images( - experiment, - [image_path], - curr_step, - annotations=annotations, - ) - _comet_image_prediction_count += 1 - - -def _log_plots(experiment, trainer) -> None: - """ - Log evaluation plots and label plots for the experiment. - - This function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles - different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots - for each type. - - Args: - experiment (comet_ml.CometExperiment): The Comet ML experiment to log plots to. - trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save - directory information. 
- - Examples: - >>> from ultralytics.utils.callbacks.comet import _log_plots - >>> _log_plots(experiment, trainer) - """ - plot_filenames = None - if isinstance(trainer.validator.metrics, SegmentMetrics): - plot_filenames = [ - trainer.save_dir / f"{prefix}{plots}.png" - for plots in EVALUATION_PLOT_NAMES - for prefix in SEGMENT_METRICS_PLOT_PREFIX - ] - elif isinstance(trainer.validator.metrics, PoseMetrics): - plot_filenames = [ - trainer.save_dir / f"{prefix}{plots}.png" - for plots in EVALUATION_PLOT_NAMES - for prefix in POSE_METRICS_PLOT_PREFIX - ] - elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)): - plot_filenames = [ - trainer.save_dir / f"{prefix}{plots}.png" - for plots in EVALUATION_PLOT_NAMES - for prefix in DETECTION_METRICS_PLOT_PREFIX - ] - - if plot_filenames is not None: - _log_images(experiment, plot_filenames, None) - - confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES] - _log_images(experiment, confusion_matrix_filenames, None) - - if not isinstance(trainer.validator.metrics, ClassifyMetrics): - label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES] - _log_images(experiment, label_plot_filenames, None) - - -def _log_model(experiment, trainer) -> None: - """Log the best-trained model to Comet.ml.""" - model_name = _get_comet_model_name() - experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True) - - -def _log_image_batches(experiment, trainer, curr_step: int) -> None: - """Log samples of image batches for train, validation, and test.""" - _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step) - _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step) - - -def _log_asset(experiment, asset_path) -> None: - """ - Logs a specific asset file to the given experiment. 
- - This function facilitates logging an asset, such as a file, to the provided - experiment. It enables integration with experiment tracking platforms. - - Args: - experiment (comet_ml.CometExperiment): The experiment instance to which the asset will be logged. - asset_path (Path): The file path of the asset to log. - """ - experiment.log_asset(asset_path) - - -def _log_table(experiment, table_path) -> None: - """ - Logs a table to the provided experiment. - - This function is used to log a table file to the given experiment. The table - is identified by its file path. - - Args: - experiment (comet_ml.CometExperiment): The experiment object where the table file will be logged. - table_path (Path): The file path of the table to be logged. - """ - experiment.log_table(str(table_path)) - - -def on_pretrain_routine_start(trainer) -> None: - """Create or resume a CometML experiment at the start of a YOLO pre-training routine.""" - _resume_or_create_experiment(trainer.args) - - -def on_train_epoch_end(trainer) -> None: - """Log metrics and save batch images at the end of training epochs.""" - experiment = comet_ml.get_running_experiment() - if not experiment: - return - - metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata["curr_epoch"] - curr_step = metadata["curr_step"] - - experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch) - - -def on_fit_epoch_end(trainer) -> None: - """ - Log model assets at the end of each epoch during training. - - This function is called at the end of each training epoch to log metrics, learning rates, and model information - to a Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on - configuration settings. - - The function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch, - it also logs model information. 
On specified save intervals, it logs the model, confusion matrix (if enabled), - and image predictions (if enabled). - - Args: - trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration. - - Examples: - >>> # Inside a training loop - >>> on_fit_epoch_end(trainer) # Log metrics and assets to Comet ML - """ - experiment = comet_ml.get_running_experiment() - if not experiment: - return - - metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata["curr_epoch"] - curr_step = metadata["curr_step"] - save_assets = metadata["save_assets"] - - experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch) - experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch) - if curr_epoch == 1: - from ultralytics.utils.torch_utils import model_info_for_loggers - - experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch) - - if not save_assets: - return - - _log_model(experiment, trainer) - if _should_log_confusion_matrix(): - _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) - if _should_log_image_predictions(): - _log_image_predictions(experiment, trainer.validator, curr_step) - - -def on_train_end(trainer) -> None: - """Perform operations at the end of training.""" - experiment = comet_ml.get_running_experiment() - if not experiment: - return - - metadata = _fetch_trainer_metadata(trainer) - curr_epoch = metadata["curr_epoch"] - curr_step = metadata["curr_step"] - plots = trainer.args.plots - - _log_model(experiment, trainer) - if plots: - _log_plots(experiment, trainer) - - _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) - _log_image_predictions(experiment, trainer.validator, curr_step) - _log_image_batches(experiment, trainer, curr_step) - # log results table - table_path = trainer.save_dir / RESULTS_TABLE_NAME - if table_path.exists(): - _log_table(experiment, table_path) - - # log arguments YAML - args_path = trainer.save_dir / 
ARGS_YAML_NAME - if args_path.exists(): - _log_asset(experiment, args_path) - - experiment.end() - - global _comet_image_prediction_count - _comet_image_prediction_count = 0 - - -callbacks = ( - { - "on_pretrain_routine_start": on_pretrain_routine_start, - "on_train_epoch_end": on_train_epoch_end, - "on_fit_epoch_end": on_fit_epoch_end, - "on_train_end": on_train_end, - } - if comet_ml - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/dvc.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/dvc.py deleted file mode 100644 index 35a16d7..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/dvc.py +++ /dev/null @@ -1,202 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from pathlib import Path - -from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks - -try: - assert not TESTS_RUNNING # do not log pytest - assert SETTINGS["dvc"] is True # verify integration is enabled - import dvclive - - assert checks.check_version("dvclive", "2.11.0", verbose=True) - - import os - import re - - # DVCLive logger instance - live = None - _processed_plots = {} - - # `on_fit_epoch_end` is called on final validation (probably need to be fixed) for now this is the way we - # distinguish final evaluation of the best model vs last epoch validation - _training_epoch = False - -except (ImportError, AssertionError, TypeError): - dvclive = None - - -def _log_images(path: Path, prefix: str = "") -> None: - """ - Log images at specified path with an optional prefix using DVCLive. - - This function logs images found at the given path to DVCLive, organizing them by batch to enable slider - functionality in the UI. It processes image filenames to extract batch information and restructures the path - accordingly. - - Args: - path (Path): Path to the image file to be logged. - prefix (str, optional): Optional prefix to add to the image name when logging. 
- - Examples: - >>> from pathlib import Path - >>> _log_images(Path("runs/train/exp/val_batch0_pred.jpg"), prefix="validation") - """ - if live: - name = path.name - - # Group images by batch to enable sliders in UI - if m := re.search(r"_batch(\d+)", name): - ni = m[1] - new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem) - name = (Path(new_stem) / ni).with_suffix(path.suffix) - - live.log_image(os.path.join(prefix, name), path) - - -def _log_plots(plots: dict, prefix: str = "") -> None: - """ - Log plot images for training progress if they have not been previously processed. - - Args: - plots (dict): Dictionary containing plot information with timestamps. - prefix (str, optional): Optional prefix to add to the logged image paths. - """ - for name, params in plots.items(): - timestamp = params["timestamp"] - if _processed_plots.get(name) != timestamp: - _log_images(name, prefix) - _processed_plots[name] = timestamp - - -def _log_confusion_matrix(validator) -> None: - """ - Log confusion matrix for a validator using DVCLive. - - This function processes the confusion matrix from a validator object and logs it to DVCLive by converting - the matrix into lists of target and prediction labels. - - Args: - validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have - attributes: confusion_matrix.matrix, confusion_matrix.task, and names. 
- """ - targets = [] - preds = [] - matrix = validator.confusion_matrix.matrix - names = list(validator.names.values()) - if validator.confusion_matrix.task == "detect": - names += ["background"] - - for ti, pred in enumerate(matrix.T.astype(int)): - for pi, num in enumerate(pred): - targets.extend([names[ti]] * num) - preds.extend([names[pi]] * num) - - live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True) - - -def on_pretrain_routine_start(trainer) -> None: - """Initialize DVCLive logger for training metadata during pre-training routine.""" - try: - global live - live = dvclive.Live(save_dvc_exp=True, cache_images=True) - LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).") - except Exception as e: - LOGGER.warning(f"DVCLive installed but not initialized correctly, not logging this run. {e}") - - -def on_pretrain_routine_end(trainer) -> None: - """Log plots related to the training process at the end of the pretraining routine.""" - _log_plots(trainer.plots, "train") - - -def on_train_start(trainer) -> None: - """Log the training parameters if DVCLive logging is active.""" - if live: - live.log_params(trainer.args) - - -def on_train_epoch_start(trainer) -> None: - """Set the global variable _training_epoch value to True at the start of training each epoch.""" - global _training_epoch - _training_epoch = True - - -def on_fit_epoch_end(trainer) -> None: - """ - Log training metrics, model info, and advance to next step at the end of each fit epoch. - - This function is called at the end of each fit epoch during training. It logs various metrics including - training loss items, validation metrics, and learning rates. On the first epoch, it also logs model - information. Additionally, it logs training and validation plots and advances the DVCLive step counter. - - Args: - trainer (BaseTrainer): The trainer object containing training state, metrics, and plots. 
- - Notes: - This function only performs logging operations when DVCLive logging is active and during a training epoch. - The global variable _training_epoch is used to track whether the current epoch is a training epoch. - """ - global _training_epoch - if live and _training_epoch: - all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} - for metric, value in all_metrics.items(): - live.log_metric(metric, value) - - if trainer.epoch == 0: - from ultralytics.utils.torch_utils import model_info_for_loggers - - for metric, value in model_info_for_loggers(trainer).items(): - live.log_metric(metric, value, plot=False) - - _log_plots(trainer.plots, "train") - _log_plots(trainer.validator.plots, "val") - - live.next_step() - _training_epoch = False - - -def on_train_end(trainer) -> None: - """ - Log best metrics, plots, and confusion matrix at the end of training. - - This function is called at the conclusion of the training process to log final metrics, visualizations, and - model artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots, - validation plots, and confusion matrix for later analysis. - - Args: - trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results. - - Examples: - >>> # Inside a custom training loop - >>> from ultralytics.utils.callbacks.dvc import on_train_end - >>> on_train_end(trainer) # Log final metrics and artifacts - """ - if live: - # At the end log the best metrics. It runs validator on the best model internally. 
- all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr} - for metric, value in all_metrics.items(): - live.log_metric(metric, value, plot=False) - - _log_plots(trainer.plots, "val") - _log_plots(trainer.validator.plots, "val") - _log_confusion_matrix(trainer.validator) - - if trainer.best.exists(): - live.log_artifact(trainer.best, copy=True, type="model") - - live.end() - - -callbacks = ( - { - "on_pretrain_routine_start": on_pretrain_routine_start, - "on_pretrain_routine_end": on_pretrain_routine_end, - "on_train_start": on_train_start, - "on_train_epoch_start": on_train_epoch_start, - "on_fit_epoch_end": on_fit_epoch_end, - "on_train_end": on_train_end, - } - if dvclive - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/hub.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/hub.py deleted file mode 100644 index fc81b7d..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/hub.py +++ /dev/null @@ -1,109 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import json -from time import time - -from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession, events -from ultralytics.utils import LOGGER, RANK, SETTINGS - - -def on_pretrain_routine_start(trainer): - """Create a remote Ultralytics HUB session to log local model training.""" - if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None: - trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args) - - -def on_pretrain_routine_end(trainer): - """Initialize timers for upload rate limiting before training begins.""" - if session := getattr(trainer, "hub_session", None): - # Start timer for upload rate limit - session.timers = {"metrics": time(), "ckpt": time()} # start timer for session rate limiting - - -def on_fit_epoch_end(trainer): - """Upload training progress 
metrics to Ultralytics HUB at the end of each epoch.""" - if session := getattr(trainer, "hub_session", None): - # Upload metrics after validation ends - all_plots = { - **trainer.label_loss_items(trainer.tloss, prefix="train"), - **trainer.metrics, - } - if trainer.epoch == 0: - from ultralytics.utils.torch_utils import model_info_for_loggers - - all_plots = {**all_plots, **model_info_for_loggers(trainer)} - - session.metrics_queue[trainer.epoch] = json.dumps(all_plots) - - # If any metrics failed to upload previously, add them to the queue to attempt uploading again - if session.metrics_upload_failed_queue: - session.metrics_queue.update(session.metrics_upload_failed_queue) - - if time() - session.timers["metrics"] > session.rate_limits["metrics"]: - session.upload_metrics() - session.timers["metrics"] = time() # reset timer - session.metrics_queue = {} # reset queue - - -def on_model_save(trainer): - """Upload model checkpoints to Ultralytics HUB with rate limiting.""" - if session := getattr(trainer, "hub_session", None): - # Upload checkpoints with rate limiting - is_best = trainer.best_fitness == trainer.fitness - if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]: - LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}") - session.upload_model(trainer.epoch, trainer.last, is_best) - session.timers["ckpt"] = time() # reset timer - - -def on_train_end(trainer): - """Upload final model and metrics to Ultralytics HUB at the end of training.""" - if session := getattr(trainer, "hub_session", None): - # Upload final model and metrics with exponential standoff - LOGGER.info(f"{PREFIX}Syncing final model...") - session.upload_model( - trainer.epoch, - trainer.best, - map=trainer.metrics.get("metrics/mAP50-95(B)", 0), - final=True, - ) - session.alive = False # stop heartbeats - LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model at {session.model_url} 🚀") - - -def on_train_start(trainer): - """Run events on train start.""" - 
events(trainer.args, trainer.device) - - -def on_val_start(validator): - """Run events on validation start.""" - if not validator.training: - events(validator.args, validator.device) - - -def on_predict_start(predictor): - """Run events on predict start.""" - events(predictor.args, predictor.device) - - -def on_export_start(exporter): - """Run events on export start.""" - events(exporter.args, exporter.device) - - -callbacks = ( - { - "on_pretrain_routine_start": on_pretrain_routine_start, - "on_pretrain_routine_end": on_pretrain_routine_end, - "on_fit_epoch_end": on_fit_epoch_end, - "on_model_save": on_model_save, - "on_train_end": on_train_end, - "on_train_start": on_train_start, - "on_val_start": on_val_start, - "on_predict_start": on_predict_start, - "on_export_start": on_export_start, - } - if SETTINGS["hub"] is True - else {} -) # verify hub is enabled before registering callbacks diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/mlflow.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/mlflow.py deleted file mode 100644 index f570240..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/mlflow.py +++ /dev/null @@ -1,135 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -""" -MLflow Logging for Ultralytics YOLO. - -This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts. -For setting up, a tracking URI should be specified. The logging can be customized using environment variables. - -Commands: - 1. To set a project name: - `export MLFLOW_EXPERIMENT_NAME=` or use the project= argument - - 2. To set a run name: - `export MLFLOW_RUN=` or use the name= argument - - 3. To start a local MLflow server: - mlflow server --backend-store-uri runs/mlflow - It will by default start a local server at http://127.0.0.1:5000. - To specify a different URI, set the MLFLOW_TRACKING_URI environment variable. - - 4. 
To kill all running MLflow server instances: - ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9 -""" - -from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr - -try: - import os - - assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest - assert SETTINGS["mlflow"] is True # verify integration is enabled - import mlflow - - assert hasattr(mlflow, "__version__") # verify package is not directory - from pathlib import Path - - PREFIX = colorstr("MLflow: ") - -except (ImportError, AssertionError): - mlflow = None - - -def sanitize_dict(x: dict) -> dict: - """Sanitize dictionary keys by removing parentheses and converting values to floats.""" - return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()} - - -def on_pretrain_routine_end(trainer): - """ - Log training parameters to MLflow at the end of the pretraining routine. - - This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI, - experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters - from the trainer. - - Args: - trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log. - - Environment Variables: - MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'. - MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project. - MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name. - MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends. 
- """ - global mlflow - - uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow") - LOGGER.debug(f"{PREFIX} tracking uri: {uri}") - mlflow.set_tracking_uri(uri) - - # Set experiment and run names - experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics" - run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name - mlflow.set_experiment(experiment_name) - - mlflow.autolog() - try: - active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name) - LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}") - if Path(uri).is_dir(): - LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'") - LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'") - mlflow.log_params(dict(trainer.args)) - except Exception as e: - LOGGER.warning(f"{PREFIX}Failed to initialize: {e}") - LOGGER.warning(f"{PREFIX}Not tracking this run") - - -def on_train_epoch_end(trainer): - """Log training metrics at the end of each train epoch to MLflow.""" - if mlflow: - mlflow.log_metrics( - metrics={ - **sanitize_dict(trainer.lr), - **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")), - }, - step=trainer.epoch, - ) - - -def on_fit_epoch_end(trainer): - """Log training metrics at the end of each fit epoch to MLflow.""" - if mlflow: - mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch) - - -def on_train_end(trainer): - """Log model artifacts at the end of training.""" - if not mlflow: - return - mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt - for f in trainer.save_dir.glob("*"): # log all other files in save_dir - if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}: - mlflow.log_artifact(str(f)) - keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true" - if keep_run_active: - LOGGER.info(f"{PREFIX}mlflow run still alive, 
remember to close it using mlflow.end_run()") - else: - mlflow.end_run() - LOGGER.debug(f"{PREFIX}mlflow run ended") - - LOGGER.info( - f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'" - ) - - -callbacks = ( - { - "on_pretrain_routine_end": on_pretrain_routine_end, - "on_train_epoch_end": on_train_epoch_end, - "on_fit_epoch_end": on_fit_epoch_end, - "on_train_end": on_train_end, - } - if mlflow - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/neptune.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/neptune.py deleted file mode 100644 index b27964b..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/neptune.py +++ /dev/null @@ -1,134 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING - -try: - assert not TESTS_RUNNING # do not log pytest - assert SETTINGS["neptune"] is True # verify integration is enabled - - import neptune - from neptune.types import File - - assert hasattr(neptune, "__version__") - - run = None # NeptuneAI experiment logger instance - -except (ImportError, AssertionError): - neptune = None - - -def _log_scalars(scalars: dict, step: int = 0) -> None: - """ - Log scalars to the NeptuneAI experiment logger. - - Args: - scalars (dict): Dictionary of scalar values to log to NeptuneAI. - step (int, optional): The current step or iteration number for logging. - - Examples: - >>> metrics = {"mAP": 0.85, "loss": 0.32} - >>> _log_scalars(metrics, step=100) - """ - if run: - for k, v in scalars.items(): - run[k].append(value=v, step=step) - - -def _log_images(imgs_dict: dict, group: str = "") -> None: - """ - Log images to the NeptuneAI experiment logger. - - This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized - under the specified group name. 
- - Args: - imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data. - group (str, optional): Group name to organize images under in the Neptune UI. - - Examples: - >>> # Log validation images - >>> _log_images({"val_batch": img_tensor}, group="validation") - """ - if run: - for k, v in imgs_dict.items(): - run[f"{group}/{k}"].upload(File(v)) - - -def _log_plot(title: str, plot_path: str) -> None: - """Log plots to the NeptuneAI experiment logger.""" - import matplotlib.image as mpimg - import matplotlib.pyplot as plt - - img = mpimg.imread(plot_path) - fig = plt.figure() - ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks - ax.imshow(img) - run[f"Plots/{title}"].upload(fig) - - -def on_pretrain_routine_start(trainer) -> None: - """Initialize NeptuneAI run and log hyperparameters before training starts.""" - try: - global run - run = neptune.init_run( - project=trainer.args.project or "Ultralytics", - name=trainer.args.name, - tags=["Ultralytics"], - ) - run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()} - except Exception as e: - LOGGER.warning(f"NeptuneAI installed but not initialized correctly, not logging this run. 
{e}") - - -def on_train_epoch_end(trainer) -> None: - """Log training metrics and learning rate at the end of each training epoch.""" - _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) - _log_scalars(trainer.lr, trainer.epoch + 1) - if trainer.epoch == 1: - _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic") - - -def on_fit_epoch_end(trainer) -> None: - """Log model info and validation metrics at the end of each fit epoch.""" - if run and trainer.epoch == 0: - from ultralytics.utils.torch_utils import model_info_for_loggers - - run["Configuration/Model"] = model_info_for_loggers(trainer) - _log_scalars(trainer.metrics, trainer.epoch + 1) - - -def on_val_end(validator) -> None: - """Log validation images at the end of validation.""" - if run: - # Log val_labels and val_pred - _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation") - - -def on_train_end(trainer) -> None: - """Log final results, plots, and model weights at the end of training.""" - if run: - # Log final results, CM matrix + PR plots - files = [ - "results.png", - "confusion_matrix.png", - "confusion_matrix_normalized.png", - *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")), - ] - files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter - for f in files: - _log_plot(title=f.stem, plot_path=f) - # Log the final model - run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best))) - - -callbacks = ( - { - "on_pretrain_routine_start": on_pretrain_routine_start, - "on_train_epoch_end": on_train_epoch_end, - "on_fit_epoch_end": on_fit_epoch_end, - "on_val_end": on_val_end, - "on_train_end": on_train_end, - } - if neptune - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/raytune.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/raytune.py deleted file 
mode 100644 index 4a75a70..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/raytune.py +++ /dev/null @@ -1,43 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from ultralytics.utils import SETTINGS - -try: - assert SETTINGS["raytune"] is True # verify integration is enabled - import ray - from ray import tune - from ray.air import session - -except (ImportError, AssertionError): - tune = None - - -def on_fit_epoch_end(trainer): - """ - Report training metrics to Ray Tune at epoch end when a Ray session is active. - - Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number, - enabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session. - - Args: - trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs. - - Examples: - >>> # Called automatically by the Ultralytics training loop - >>> on_fit_epoch_end(trainer) - - References: - Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html - """ - if ray.train._internal.session.get_session(): # check if Ray Tune session is active - metrics = trainer.metrics - session.report({**metrics, **{"epoch": trainer.epoch + 1}}) - - -callbacks = ( - { - "on_fit_epoch_end": on_fit_epoch_end, - } - if tune - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/tensorboard.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/tensorboard.py deleted file mode 100644 index f5adc31..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/tensorboard.py +++ /dev/null @@ -1,131 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils - -try: - assert not TESTS_RUNNING # do not log pytest - assert SETTINGS["tensorboard"] is True # verify integration is enabled - WRITER 
= None # TensorBoard SummaryWriter instance - PREFIX = colorstr("TensorBoard: ") - - # Imports below only required if TensorBoard enabled - import warnings - from copy import deepcopy - - import torch - from torch.utils.tensorboard import SummaryWriter - -except (ImportError, AssertionError, TypeError, AttributeError): - # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows - # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed - SummaryWriter = None - - -def _log_scalars(scalars: dict, step: int = 0) -> None: - """ - Log scalar values to TensorBoard. - - Args: - scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the - corresponding scalar values. - step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs. - - Examples: - Log training metrics - >>> metrics = {"loss": 0.5, "accuracy": 0.95} - >>> _log_scalars(metrics, step=100) - """ - if WRITER: - for k, v in scalars.items(): - WRITER.add_scalar(k, v, step) - - -def _log_tensorboard_graph(trainer) -> None: - """ - Log model graph to TensorBoard. - - This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input - tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex - approach for models like RTDETR that may require special handling. - - Args: - trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize. - Must have attributes model and args with imgsz. - - Notes: - This function requires TensorBoard integration to be enabled and the global WRITER to be initialized. - It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different - model architectures. 
- """ - # Input image - imgsz = trainer.args.imgsz - imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz - p = next(trainer.model.parameters()) # for device, type - im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning - warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning - - # Try simple method first (YOLO) - try: - trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes - WRITER.add_graph(torch.jit.trace(torch_utils.de_parallel(trainer.model), im, strict=False), []) - LOGGER.info(f"{PREFIX}model graph visualization added ✅") - return - - except Exception: - # Fallback to TorchScript export steps (RTDETR) - try: - model = deepcopy(torch_utils.de_parallel(trainer.model)) - model.eval() - model = model.fuse(verbose=False) - for m in model.modules(): - if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class) - m.export = True - m.format = "torchscript" - model(im) # dry run - WRITER.add_graph(torch.jit.trace(model, im, strict=False), []) - LOGGER.info(f"{PREFIX}model graph visualization added ✅") - except Exception as e: - LOGGER.warning(f"{PREFIX}TensorBoard graph visualization failure {e}") - - -def on_pretrain_routine_start(trainer) -> None: - """Initialize TensorBoard logging with SummaryWriter.""" - if SummaryWriter: - try: - global WRITER - WRITER = SummaryWriter(str(trainer.save_dir)) - LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/") - except Exception as e: - LOGGER.warning(f"{PREFIX}TensorBoard not initialized correctly, not logging this run. 
{e}") - - -def on_train_start(trainer) -> None: - """Log TensorBoard graph.""" - if WRITER: - _log_tensorboard_graph(trainer) - - -def on_train_epoch_end(trainer) -> None: - """Log scalar statistics at the end of a training epoch.""" - _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1) - _log_scalars(trainer.lr, trainer.epoch + 1) - - -def on_fit_epoch_end(trainer) -> None: - """Log epoch metrics at end of training epoch.""" - _log_scalars(trainer.metrics, trainer.epoch + 1) - - -callbacks = ( - { - "on_pretrain_routine_start": on_pretrain_routine_start, - "on_train_start": on_train_start, - "on_fit_epoch_end": on_fit_epoch_end, - "on_train_epoch_end": on_train_epoch_end, - } - if SummaryWriter - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/wb.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/wb.py deleted file mode 100644 index 319aa68..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/callbacks/wb.py +++ /dev/null @@ -1,185 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from ultralytics.utils import SETTINGS, TESTS_RUNNING -from ultralytics.utils.torch_utils import model_info_for_loggers - -try: - assert not TESTS_RUNNING # do not log pytest - assert SETTINGS["wandb"] is True # verify integration is enabled - import wandb as wb - - assert hasattr(wb, "__version__") # verify package is not directory - _processed_plots = {} - -except (ImportError, AssertionError): - wb = None - - -def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"): - """ - Create and log a custom metric visualization to wandb.plot.pr_curve. - - This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall - curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across - different classes. 
- - Args: - x (list): Values for the x-axis; expected to have length N. - y (list): Corresponding values for the y-axis; also expected to have length N. - classes (list): Labels identifying the class of each point; length N. - title (str, optional): Title for the plot. - x_title (str, optional): Label for the x-axis. - y_title (str, optional): Label for the y-axis. - - Returns: - (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization. - """ - import pandas # scope for faster 'import ultralytics' - - df = pandas.DataFrame({"class": classes, "y": y, "x": x}).round(3) - fields = {"x": "x", "y": "y", "class": "class"} - string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title} - return wb.plot_table( - "wandb/area-under-curve/v0", wb.Table(dataframe=df), fields=fields, string_fields=string_fields - ) - - -def _plot_curve( - x, - y, - names=None, - id="precision-recall", - title="Precision Recall Curve", - x_title="Recall", - y_title="Precision", - num_x=100, - only_mean=False, -): - """ - Log a metric curve visualization. - - This function generates a metric curve based on input data and logs the visualization to wandb. - The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag. - - Args: - x (np.ndarray): Data points for the x-axis with length N. - y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes. - names (list, optional): Names of the classes corresponding to the y-axis data; length C. - id (str, optional): Unique identifier for the logged data in wandb. - title (str, optional): Title for the visualization plot. - x_title (str, optional): Label for the x-axis. - y_title (str, optional): Label for the y-axis. - num_x (int, optional): Number of interpolated data points for visualization. - only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted. 
- - Notes: - The function leverages the '_custom_table' function to generate the actual visualization. - """ - import numpy as np - - # Create new x - if names is None: - names = [] - x_new = np.linspace(x[0], x[-1], num_x).round(5) - - # Create arrays for logging - x_log = x_new.tolist() - y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist() - - if only_mean: - table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title]) - wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)}) - else: - classes = ["mean"] * len(x_log) - for i, yi in enumerate(y): - x_log.extend(x_new) # add new x - y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x - classes.extend([names[i]] * len(x_new)) # add class names - wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False) - - -def _log_plots(plots, step): - """ - Log plots to WandB at a specific step if they haven't been logged already. - - This function checks each plot in the input dictionary against previously processed plots and logs - new or updated plots to WandB at the specified step. - - Args: - plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries - containing plot metadata including timestamps. - step (int): The step/epoch at which to log the plots in the WandB run. - - Notes: - The function uses a shallow copy of the plots dictionary to prevent modification during iteration. - Plots are identified by their stem name (filename without extension). - Each plot is logged as a WandB Image object. 
- """ - for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration - timestamp = params["timestamp"] - if _processed_plots.get(name) != timestamp: - wb.run.log({name.stem: wb.Image(str(name))}, step=step) - _processed_plots[name] = timestamp - - -def on_pretrain_routine_start(trainer): - """Initialize and start wandb project if module is present.""" - if not wb.run: - wb.init( - project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics", - name=str(trainer.args.name).replace("/", "-"), - config=vars(trainer.args), - ) - - -def on_fit_epoch_end(trainer): - """Log training metrics and model information at the end of an epoch.""" - wb.run.log(trainer.metrics, step=trainer.epoch + 1) - _log_plots(trainer.plots, step=trainer.epoch + 1) - _log_plots(trainer.validator.plots, step=trainer.epoch + 1) - if trainer.epoch == 0: - wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1) - - -def on_train_epoch_end(trainer): - """Log metrics and save images at the end of each training epoch.""" - wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1) - wb.run.log(trainer.lr, step=trainer.epoch + 1) - if trainer.epoch == 1: - _log_plots(trainer.plots, step=trainer.epoch + 1) - - -def on_train_end(trainer): - """Save the best model as an artifact and log final plots at the end of training.""" - _log_plots(trainer.validator.plots, step=trainer.epoch + 1) - _log_plots(trainer.plots, step=trainer.epoch + 1) - art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model") - if trainer.best.exists(): - art.add_file(trainer.best) - wb.run.log_artifact(art, aliases=["best"]) - # Check if we actually have plots to save - if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"): - for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results): - x, y, x_title, y_title = curve_values - 
_plot_curve( - x, - y, - names=list(trainer.validator.metrics.names.values()), - id=f"curves/{curve_name}", - title=curve_name, - x_title=x_title, - y_title=y_title, - ) - wb.run.finish() # required or run continues on dashboard - - -callbacks = ( - { - "on_pretrain_routine_start": on_pretrain_routine_start, - "on_train_epoch_end": on_train_epoch_end, - "on_fit_epoch_end": on_fit_epoch_end, - "on_train_end": on_train_end, - } - if wb - else {} -) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/checks.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/checks.py deleted file mode 100644 index 72421b9..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/checks.py +++ /dev/null @@ -1,943 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import functools -import glob -import inspect -import math -import os -import platform -import re -import shutil -import subprocess -import time -from importlib import metadata -from pathlib import Path -from types import SimpleNamespace -from typing import Optional - -import cv2 -import numpy as np -import torch - -from ultralytics.utils import ( - ARM64, - ASSETS, - AUTOINSTALL, - IS_COLAB, - IS_GIT_DIR, - IS_JETSON, - IS_KAGGLE, - IS_PIP_PACKAGE, - LINUX, - LOGGER, - MACOS, - ONLINE, - PYTHON_VERSION, - RKNN_CHIPS, - ROOT, - TORCHVISION_VERSION, - USER_CONFIG_DIR, - WINDOWS, - Retry, - ThreadingLocked, - TryExcept, - clean_url, - colorstr, - downloads, - is_github_action_running, - url2file, -) - - -def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""): - """ - Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'. - - Args: - file_path (Path): Path to the requirements.txt file. - package (str, optional): Python package to use instead of requirements.txt file. 
- - Returns: - requirements (List[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and - `specifier` attributes. - - Examples: - >>> from ultralytics.utils.checks import parse_requirements - >>> parse_requirements(package="ultralytics") - """ - if package: - requires = [x for x in metadata.distribution(package).requires if "extra == " not in x] - else: - requires = Path(file_path).read_text().splitlines() - - requirements = [] - for line in requires: - line = line.strip() - if line and not line.startswith("#"): - line = line.partition("#")[0].strip() # ignore inline comments - if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line): - requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else "")) - - return requirements - - -@functools.lru_cache -def parse_version(version="0.0.0") -> tuple: - """ - Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version. - - Args: - version (str): Version string, i.e. '2.0.1+cpu' - - Returns: - (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1) - """ - try: - return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1) - except Exception as e: - LOGGER.warning(f"failure for parse_version({version}), returning (0, 0, 0): {e}") - return 0, 0, 0 - - -def is_ascii(s) -> bool: - """ - Check if a string is composed of only ASCII characters. - - Args: - s (str | list | tuple | dict): Input to be checked (all are converted to string for checking). - - Returns: - (bool): True if the string is composed only of ASCII characters, False otherwise. - """ - return all(ord(c) < 128 for c in str(s)) - - -def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0): - """ - Verify image size is a multiple of the given stride in each dimension. 
If the image size is not a multiple of the - stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value. - - Args: - imgsz (int | List[int]): Image size. - stride (int): Stride value. - min_dim (int): Minimum number of dimensions. - max_dim (int): Maximum number of dimensions. - floor (int): Minimum allowed value for image size. - - Returns: - (List[int] | int): Updated image size. - """ - # Convert stride to integer if it is a tensor - stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride) - - # Convert image size to list if it is an integer - if isinstance(imgsz, int): - imgsz = [imgsz] - elif isinstance(imgsz, (list, tuple)): - imgsz = list(imgsz) - elif isinstance(imgsz, str): # i.e. '640' or '[640,640]' - imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz) - else: - raise TypeError( - f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. " - f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'" - ) - - # Apply max_dim - if len(imgsz) > max_dim: - msg = ( - "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list " - "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'" - ) - if max_dim != 1: - raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}") - LOGGER.warning(f"updating to 'imgsz={max(imgsz)}'. 
{msg}") - imgsz = [max(imgsz)] - # Make image size a multiple of the stride - sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz] - - # Print warning message if image size was updated - if sz != imgsz: - LOGGER.warning(f"imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}") - - # Add missing dimensions if necessary - sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz - - return sz - - -@functools.lru_cache -def check_uv(): - """Check if uv package manager is installed and can run successfully.""" - try: - return subprocess.run(["uv", "-V"], capture_output=True).returncode == 0 - except FileNotFoundError: - return False - - -@functools.lru_cache -def check_version( - current: str = "0.0.0", - required: str = "0.0.0", - name: str = "version", - hard: bool = False, - verbose: bool = False, - msg: str = "", -) -> bool: - """ - Check current version against the required version or range. - - Args: - current (str): Current version or package name to get version from. - required (str): Required version or range (in pip-style format). - name (str): Name to be used in warning message. - hard (bool): If True, raise an AssertionError if the requirement is not met. - verbose (bool): If True, print warning message if requirement is not met. - msg (str): Extra message to display if verbose. - - Returns: - (bool): True if requirement is met, False otherwise. 
- - Examples: - Check if current version is exactly 22.04 - >>> check_version(current="22.04", required="==22.04") - - Check if current version is greater than or equal to 22.04 - >>> check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed - - Check if current version is less than or equal to 22.04 - >>> check_version(current="22.04", required="<=22.04") - - Check if current version is between 20.04 (inclusive) and 22.04 (exclusive) - >>> check_version(current="21.10", required=">20.04,<22.04") - """ - if not current: # if current is '' or None - LOGGER.warning(f"invalid check_version({current}, {required}) requested, please check values.") - return True - elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics' - try: - name = current # assigned package name to 'name' arg - current = metadata.version(current) # get version string from package name - except metadata.PackageNotFoundError as e: - if hard: - raise ModuleNotFoundError(f"{current} package is required but not installed") from e - else: - return False - - if not required: # if required is '' or None - return True - - if "sys_platform" in required and ( # i.e. 
required='<2.4.0,>=1.8.0; sys_platform == "win32"' - (WINDOWS and "win32" not in required) - or (LINUX and "linux" not in required) - or (MACOS and "macos" not in required and "darwin" not in required) - ): - return True - - op = "" - version = "" - result = True - c = parse_version(current) # '1.2.3' -> (1, 2, 3) - for r in required.strip(",").split(","): - op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04') - if not op: - op = ">=" # assume >= if no op passed - v = parse_version(version) # '1.2.3' -> (1, 2, 3) - if op == "==" and c != v: - result = False - elif op == "!=" and c == v: - result = False - elif op == ">=" and not (c >= v): - result = False - elif op == "<=" and not (c <= v): - result = False - elif op == ">" and not (c > v): - result = False - elif op == "<" and not (c < v): - result = False - if not result: - warning = f"{name}{required} is required, but {name}=={current} is currently installed {msg}" - if hard: - raise ModuleNotFoundError(warning) # assert version requirements met - if verbose: - LOGGER.warning(warning) - return result - - -def check_latest_pypi_version(package_name="ultralytics"): - """ - Return the latest version of a PyPI package without downloading or installing it. - - Args: - package_name (str): The name of the package to find the latest version for. - - Returns: - (str): The latest version of the package. - """ - import requests # slow import - - try: - requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning - response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3) - if response.status_code == 200: - return response.json()["info"]["version"] - except Exception: - return None - - -def check_pip_update_available(): - """ - Check if a new version of the ultralytics package is available on PyPI. - - Returns: - (bool): True if an update is available, False otherwise. 
- """ - if ONLINE and IS_PIP_PACKAGE: - try: - from ultralytics import __version__ - - latest = check_latest_pypi_version() - if check_version(__version__, f"<{latest}"): # check if current version is < latest version - LOGGER.info( - f"New https://pypi.org/project/ultralytics/{latest} available 😃 " - f"Update with 'pip install -U ultralytics'" - ) - return True - except Exception: - pass - return False - - -@ThreadingLocked() -@functools.lru_cache -def check_font(font="Arial.ttf"): - """ - Find font locally or download to user's configuration directory if it does not already exist. - - Args: - font (str): Path or name of font. - - Returns: - (Path): Resolved font file path. - """ - from matplotlib import font_manager # scope for faster 'import ultralytics' - - # Check USER_CONFIG_DIR - name = Path(font).name - file = USER_CONFIG_DIR / name - if file.exists(): - return file - - # Check system fonts - matches = [s for s in font_manager.findSystemFonts() if font in s] - if any(matches): - return matches[0] - - # Download to USER_CONFIG_DIR if missing - url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}" - if downloads.is_url(url, check=True): - downloads.safe_download(url=url, file=file) - return file - - -def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool: - """ - Check current python version against the required minimum version. - - Args: - minimum (str): Required minimum version of python. - hard (bool): If True, raise an AssertionError if the requirement is not met. - verbose (bool): If True, print warning message if requirement is not met. - - Returns: - (bool): Whether the installed Python version meets the minimum constraints. 
- """ - return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose) - - -@TryExcept() -def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""): - """ - Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed. - - Args: - requirements (Path | str | List[str]): Path to a requirements.txt file, a single package requirement as a - string, or a list of package requirements as strings. - exclude (tuple): Tuple of package names to exclude from checking. - install (bool): If True, attempt to auto-update packages that don't meet requirements. - cmds (str): Additional commands to pass to the pip install command when auto-updating. - - Examples: - >>> from ultralytics.utils.checks import check_requirements - - Check a requirements.txt file - >>> check_requirements("path/to/requirements.txt") - - Check a single package - >>> check_requirements("ultralytics>=8.0.0") - - Check multiple packages - >>> check_requirements(["numpy", "ultralytics>=8.0.0"]) - """ - prefix = colorstr("red", "bold", "requirements:") - if isinstance(requirements, Path): # requirements.txt file - file = requirements.resolve() - assert file.exists(), f"{prefix} {file} not found, check failed." 
- requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude] - elif isinstance(requirements, str): - requirements = [requirements] - - pkgs = [] - for r in requirements: - r_stripped = r.rpartition("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo' - match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped) - name, required = match[1], match[2].strip() if match[2] else "" - try: - assert check_version(metadata.version(name), required) # exception if requirements not met - except (AssertionError, metadata.PackageNotFoundError): - pkgs.append(r) - - @Retry(times=2, delay=1) - def attempt_install(packages, commands, use_uv): - """Attempt package installation with uv if available, falling back to pip.""" - if use_uv: - base = f"uv pip install --no-cache-dir {packages} {commands} --index-strategy=unsafe-best-match --break-system-packages --prerelease=allow" - try: - return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE).decode() - except subprocess.CalledProcessError as e: - if e.stderr and "No virtual environment found" in e.stderr.decode(): - return subprocess.check_output( - base.replace("uv pip install", "uv pip install --system"), shell=True - ).decode() - raise - return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True).decode() - - s = " ".join(f'"{x}"' for x in pkgs) # console string - if s: - if install and AUTOINSTALL: # check environment variable - # Note uv fails on arm64 macOS and Raspberry Pi runners - n = len(pkgs) # number of packages updates - LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...") - try: - t = time.time() - assert ONLINE, "AutoUpdate skipped (offline)" - LOGGER.info(attempt_install(s, cmds, use_uv=not ARM64 and check_uv())) - dt = time.time() - t - LOGGER.info(f"{prefix} AutoUpdate success ✅ {dt:.1f}s") - LOGGER.warning( - f"{prefix} {colorstr('bold', 'Restart 
runtime or rerun command for updates to take effect')}\n" - ) - except Exception as e: - LOGGER.warning(f"{prefix} ❌ {e}") - return False - else: - return False - - return True - - -def check_torchvision(): - """ - Check the installed versions of PyTorch and Torchvision to ensure they're compatible. - - This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according - to the compatibility table based on: https://github.com/pytorch/vision#installation. - """ - compatibility_table = { - "2.7": ["0.22"], - "2.6": ["0.21"], - "2.5": ["0.20"], - "2.4": ["0.19"], - "2.3": ["0.18"], - "2.2": ["0.17"], - "2.1": ["0.16"], - "2.0": ["0.15"], - "1.13": ["0.14"], - "1.12": ["0.13"], - } - - # Check major and minor versions - v_torch = ".".join(torch.__version__.split("+", 1)[0].split(".")[:2]) - if v_torch in compatibility_table: - compatible_versions = compatibility_table[v_torch] - v_torchvision = ".".join(TORCHVISION_VERSION.split("+", 1)[0].split(".")[:2]) - if all(v_torchvision != v for v in compatible_versions): - LOGGER.warning( - f"torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n" - f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or " - "'pip install -U torch torchvision' to update both.\n" - "For a full compatibility table see https://github.com/pytorch/vision#installation" - ) - - -def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""): - """ - Check file(s) for acceptable suffix. - - Args: - file (str | List[str]): File or list of files to check. - suffix (str | tuple): Acceptable suffix or tuple of suffixes. - msg (str): Additional message to display in case of error. 
- """ - if file and suffix: - if isinstance(suffix, str): - suffix = {suffix} - for f in file if isinstance(file, (list, tuple)) else [file]: - if s := str(f).rpartition(".")[-1].lower().strip(): # file suffix - assert f".{s}" in suffix, f"{msg}{f} acceptable suffix is {suffix}, not .{s}" - - -def check_yolov5u_filename(file: str, verbose: bool = True): - """ - Replace legacy YOLOv5 filenames with updated YOLOv5u filenames. - - Args: - file (str): Filename to check and potentially update. - verbose (bool): Whether to print information about the replacement. - - Returns: - (str): Updated filename. - """ - if "yolov3" in file or "yolov5" in file: - if "u.yaml" in file: - file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml - elif ".pt" in file and "u" not in file: - original_file = file - file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt - file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt - file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt - if file != original_file and verbose: - LOGGER.info( - f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are " - f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs " - f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n" - ) - return file - - -def check_model_file_from_stem(model="yolo11n"): - """ - Return a model filename from a valid model stem. - - Args: - model (str): Model stem to check. - - Returns: - (str | Path): Model filename with appropriate suffix. - """ - path = Path(model) - if not path.suffix and path.stem in downloads.GITHUB_ASSETS_STEMS: - return path.with_suffix(".pt") # add suffix, i.e. 
yolo11n -> yolo11n.pt - return model - - -def check_file(file, suffix="", download=True, download_dir=".", hard=True): - """ - Search/download file (if necessary), check suffix (if provided), and return path. - - Args: - file (str): File name or path. - suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file. - download (bool): Whether to download the file if it doesn't exist locally. - download_dir (str): Directory to download the file to. - hard (bool): Whether to raise an error if the file is not found. - - Returns: - (str): Path to the file. - """ - check_suffix(file, suffix) # optional - file = str(file).strip() # convert to string and strip spaces - file = check_yolov5u_filename(file) # yolov5n -> yolov5nu - if ( - not file - or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10 - or file.lower().startswith("grpc://") - ): # file exists or gRPC Triton images - return file - elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download - url = file # warning: Pathlib turns :// -> :/ - file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth - if file.exists(): - LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists - else: - downloads.safe_download(url=url, file=file, unzip=False) - return str(file) - else: # search - files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file - if not files and hard: - raise FileNotFoundError(f"'{file}' does not exist") - elif len(files) > 1 and hard: - raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}") - return files[0] if len(files) else [] # return file - - -def check_yaml(file, suffix=(".yaml", ".yml"), hard=True): - """ - Search/download YAML file (if necessary) and return path, checking suffix. - - Args: - file (str | Path): File name or path. 
- suffix (tuple): Tuple of acceptable YAML file suffixes. - hard (bool): Whether to raise an error if the file is not found or multiple files are found. - - Returns: - (str): Path to the YAML file. - """ - return check_file(file, suffix, hard=hard) - - -def check_is_path_safe(basedir, path): - """ - Check if the resolved path is under the intended directory to prevent path traversal. - - Args: - basedir (Path | str): The intended directory. - path (Path | str): The path to check. - - Returns: - (bool): True if the path is safe, False otherwise. - """ - base_dir_resolved = Path(basedir).resolve() - path_resolved = Path(path).resolve() - - return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts - - -@functools.lru_cache -def check_imshow(warn=False): - """ - Check if environment supports image displays. - - Args: - warn (bool): Whether to warn if environment doesn't support image displays. - - Returns: - (bool): True if environment supports image displays, False otherwise. - """ - try: - if LINUX: - assert not IS_COLAB and not IS_KAGGLE - assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set." - cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8-pixel image - cv2.waitKey(1) - cv2.destroyAllWindows() - cv2.waitKey(1) - return True - except Exception as e: - if warn: - LOGGER.warning(f"Environment does not support cv2.imshow() or PIL Image.show()\n{e}") - return False - - -def check_yolo(verbose=True, device=""): - """ - Return a human-readable YOLO software and hardware summary. - - Args: - verbose (bool): Whether to print verbose information. - device (str | torch.device): Device to use for YOLO. 
- """ - import psutil - - from ultralytics.utils.torch_utils import select_device - - if IS_COLAB: - shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory - - if verbose: - # System info - gib = 1 << 30 # bytes per GiB - ram = psutil.virtual_memory().total - total, used, free = shutil.disk_usage("/") - s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)" - try: - from IPython import display - - display.clear_output() # clear display if notebook - except ImportError: - pass - else: - s = "" - - select_device(device=device, newline=False) - LOGGER.info(f"Setup complete ✅ {s}") - - -def collect_system_info(): - """ - Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA. - - Returns: - (dict): Dictionary containing system information. - """ - import psutil - - from ultralytics.utils import ENVIRONMENT # scope to avoid circular import - from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info - - gib = 1 << 30 # bytes per GiB - cuda = torch.cuda.is_available() - check_yolo() - total, used, free = shutil.disk_usage("/") - - info_dict = { - "OS": platform.platform(), - "Environment": ENVIRONMENT, - "Python": PYTHON_VERSION, - "Install": "git" if IS_GIT_DIR else "pip" if IS_PIP_PACKAGE else "other", - "Path": str(ROOT), - "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB", - "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB", - "CPU": get_cpu_info(), - "CPU count": os.cpu_count(), - "GPU": get_gpu_info(index=0) if cuda else None, - "GPU count": torch.cuda.device_count() if cuda else None, - "CUDA": torch.version.cuda if cuda else None, - } - LOGGER.info("\n" + "\n".join(f"{k:<20}{v}" for k, v in info_dict.items()) + "\n") - - package_info = {} - for r in parse_requirements(package="ultralytics"): - try: - current = metadata.version(r.name) - is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=True) else 
"❌ " - except metadata.PackageNotFoundError: - current = "(not installed)" - is_met = "❌ " - package_info[r.name] = f"{is_met}{current}{r.specifier}" - LOGGER.info(f"{r.name:<20}{package_info[r.name]}") - - info_dict["Package Info"] = package_info - - if is_github_action_running(): - github_info = { - "RUNNER_OS": os.getenv("RUNNER_OS"), - "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"), - "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"), - "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"), - "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"), - "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"), - } - LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items())) - info_dict["GitHub Info"] = github_info - - return info_dict - - -def check_amp(model): - """ - Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model. - - If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP - results, so AMP will be disabled during training. - - Args: - model (torch.nn.Module): A YOLO model instance. - - Returns: - (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False. - - Examples: - >>> from ultralytics import YOLO - >>> from ultralytics.utils.checks import check_amp - >>> model = YOLO("yolo11n.pt").model.cuda() - >>> check_amp(model) - """ - from ultralytics.utils.torch_utils import autocast - - device = next(model.parameters()).device # get model device - prefix = colorstr("AMP: ") - if device.type in {"cpu", "mps"}: - return False # AMP only used on CUDA devices - else: - # GPUs that have issues with AMP - pattern = re.compile( - r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE - ) - - gpu = torch.cuda.get_device_name(device) - if bool(pattern.search(gpu)): - LOGGER.warning( - f"{prefix}checks failed ❌. 
AMP training on {gpu} GPU may cause " - f"NaN losses or zero-mAP results, so AMP will be disabled during training." - ) - return False - - def amp_allclose(m, im): - """All close FP32 vs AMP results.""" - batch = [im] * 8 - imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64 - a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference - with autocast(enabled=True): - b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference - del m - return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance - - im = ASSETS / "bus.jpg" # image to check - LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...") - warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False." - try: - from ultralytics import YOLO - - assert amp_allclose(YOLO("yolo11n.pt"), im) - LOGGER.info(f"{prefix}checks passed ✅") - except ConnectionError: - LOGGER.warning(f"{prefix}checks skipped. Offline and unable to download YOLO11n for AMP checks. {warning_msg}") - except (AttributeError, ModuleNotFoundError): - LOGGER.warning( - f"{prefix}checks skipped. " - f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}" - ) - except AssertionError: - LOGGER.error( - f"{prefix}checks failed. Anomalies were detected with AMP on your system that may lead to " - f"NaN losses or zero-mAP results, so AMP will be disabled during training." - ) - return False - return True - - -def git_describe(path=ROOT): # path must be a directory - """ - Return human-readable git description, i.e. v5.0-5-g3e25f1e https://git-scm.com/docs/git-describe. - - Args: - path (Path): Path to git repository. - - Returns: - (str): Human-readable git description. 
- """ - try: - return subprocess.check_output(f"git -C {path} describe --tags --long --always", shell=True).decode()[:-1] - except Exception: - return "" - - -def print_args(args: Optional[dict] = None, show_file=True, show_func=False): - """ - Print function arguments (optional args dict). - - Args: - args (dict, optional): Arguments to print. - show_file (bool): Whether to show the file name. - show_func (bool): Whether to show the function name. - """ - - def strip_auth(v): - """Clean longer Ultralytics HUB URLs by stripping potential authentication information.""" - return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v - - x = inspect.currentframe().f_back # previous frame - file, _, func, _, _ = inspect.getframeinfo(x) - if args is None: # get args automatically - args, _, _, frm = inspect.getargvalues(x) - args = {k: v for k, v in frm.items() if k in args} - try: - file = Path(file).resolve().relative_to(ROOT).with_suffix("") - except ValueError: - file = Path(file).stem - s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "") - LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in sorted(args.items()))) - - -def cuda_device_count() -> int: - """ - Get the number of NVIDIA GPUs available in the environment. - - Returns: - (int): The number of NVIDIA GPUs available. 
- """ - if IS_JETSON: - # NVIDIA Jetson does not fully support nvidia-smi and therefore use PyTorch instead - return torch.cuda.device_count() - else: - try: - # Run the nvidia-smi command and capture its output - output = subprocess.check_output( - ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8" - ) - - # Take the first line and strip any leading/trailing white space - first_line = output.strip().split("\n", 1)[0] - - return int(first_line) - except (subprocess.CalledProcessError, FileNotFoundError, ValueError): - # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available - return 0 - - -def cuda_is_available() -> bool: - """ - Check if CUDA is available in the environment. - - Returns: - (bool): True if one or more NVIDIA GPUs are available, False otherwise. - """ - return cuda_device_count() > 0 - - -def is_rockchip(): - """ - Check if the current environment is running on a Rockchip SoC. - - Returns: - (bool): True if running on a Rockchip SoC, False otherwise. - """ - if LINUX and ARM64: - try: - with open("/proc/device-tree/compatible") as f: - dev_str = f.read() - *_, soc = dev_str.split(",") - if soc.replace("\x00", "") in RKNN_CHIPS: - return True - except OSError: - return False - else: - return False - - -def is_intel(): - """ - Check if the system has Intel hardware (CPU or GPU). - - Returns: - (bool): True if Intel hardware is detected, False otherwise. - """ - from ultralytics.utils.torch_utils import get_cpu_info - - # Check CPU - if "intel" in get_cpu_info().lower(): - return True - - # Check GPU via xpu-smi - try: - result = subprocess.run(["xpu-smi", "discovery"], capture_output=True, text=True, timeout=5) - return "intel" in result.stdout.lower() - except (subprocess.TimeoutExpired, FileNotFoundError, subprocess.SubprocessError): - return False - - -def is_sudo_available() -> bool: - """ - Check if the sudo command is available in the environment. 
- - Returns: - (bool): True if the sudo command is available, False otherwise. - """ - if WINDOWS: - return False - cmd = "sudo --version" - return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0 - - -# Run checks and define constants -check_python("3.8", hard=False, verbose=True) # check python version -check_torchvision() # check torch-torchvision compatibility - -# Define constants -IS_PYTHON_3_8 = PYTHON_VERSION.startswith("3.8") -IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12") -IS_PYTHON_3_13 = PYTHON_VERSION.startswith("3.13") - -IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False) -IS_PYTHON_MINIMUM_3_12 = check_python("3.12", hard=False) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/dist.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/dist.py deleted file mode 100644 index d117dcb..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/dist.py +++ /dev/null @@ -1,119 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import os -import shutil -import sys -import tempfile - -from . import USER_CONFIG_DIR -from .torch_utils import TORCH_1_9 - - -def find_free_network_port() -> int: - """ - Find a free port on localhost. - - It is useful in single-node training when we don't want to connect to a real main node but have to set the - `MASTER_PORT` environment variable. - - Returns: - (int): The available network port number. - """ - import socket - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] # port - - -def generate_ddp_file(trainer): - """ - Generate a DDP (Distributed Data Parallel) file for multi-GPU training. - - This function creates a temporary Python file that enables distributed training across multiple GPUs. - The file contains the necessary configuration to initialize the trainer in a distributed environment. 
- - Args: - trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments. - Must have args attribute and be a class instance. - - Returns: - (str): Path to the generated temporary DDP file. - - Notes: - The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes: - - Trainer class import - - Configuration overrides from the trainer arguments - - Model path configuration - - Training initialization code - """ - module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1) - - content = f""" -# Ultralytics Multi-GPU training temp file (should be automatically deleted after use) -overrides = {vars(trainer.args)} - -if __name__ == "__main__": - from {module} import {name} - from ultralytics.utils import DEFAULT_CFG_DICT - - cfg = DEFAULT_CFG_DICT.copy() - cfg.update(save_dir='') # handle the extra key 'save_dir' - trainer = {name}(cfg=cfg, overrides=overrides) - trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}" - results = trainer.train() -""" - (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True) - with tempfile.NamedTemporaryFile( - prefix="_temp_", - suffix=f"{id(trainer)}.py", - mode="w+", - encoding="utf-8", - dir=USER_CONFIG_DIR / "DDP", - delete=False, - ) as file: - file.write(content) - return file.name - - -def generate_ddp_command(world_size: int, trainer): - """ - Generate command for distributed training. - - Args: - world_size (int): Number of processes to spawn for distributed training. - trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training. - - Returns: - cmd (List[str]): The command to execute for distributed training. - file (str): Path to the temporary file created for DDP training. 
- """ - import __main__ # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218 - - if not trainer.resume: - shutil.rmtree(trainer.save_dir) # remove the save_dir - file = generate_ddp_file(trainer) - dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch" - port = find_free_network_port() - cmd = [sys.executable, "-m", dist_cmd, "--nproc_per_node", f"{world_size}", "--master_port", f"{port}", file] - return cmd, file - - -def ddp_cleanup(trainer, file): - """ - Delete temporary file if created during distributed data parallel (DDP) training. - - This function checks if the provided file contains the trainer's ID in its name, indicating it was created - as a temporary file for DDP training, and deletes it if so. - - Args: - trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training. - file (str): Path to the file that might need to be deleted. - - Examples: - >>> trainer = YOLOTrainer() - >>> file = "/tmp/ddp_temp_123456789.py" - >>> ddp_cleanup(trainer, file) - """ - if f"{id(trainer)}.py" in file: # if temp_file suffix in file - os.remove(file) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/downloads.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/downloads.py deleted file mode 100644 index 7bffa56..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/downloads.py +++ /dev/null @@ -1,523 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import re -import shutil -import subprocess -from itertools import repeat -from multiprocessing.pool import ThreadPool -from pathlib import Path -from typing import List, Tuple -from urllib import parse, request - -from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file - -# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets -GITHUB_ASSETS_REPO = "ultralytics/assets" -GITHUB_ASSETS_NAMES = 
frozenset( - [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")] - + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")] - + [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)] # detect models only currently - + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")] - + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")] - + [f"yolov8{k}-world.pt" for k in "smlx"] - + [f"yolov8{k}-worldv2.pt" for k in "smlx"] - + [f"yoloe-v8{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")] - + [f"yoloe-11{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")] - + [f"yolov9{k}.pt" for k in "tsmce"] - + [f"yolov10{k}.pt" for k in "nsmblx"] - + [f"yolo_nas_{k}.pt" for k in "sml"] - + [f"sam_{k}.pt" for k in "bl"] - + [f"sam2_{k}.pt" for k in "blst"] - + [f"sam2.1_{k}.pt" for k in "blst"] - + [f"FastSAM-{k}.pt" for k in "sx"] - + [f"rtdetr-{k}.pt" for k in "lx"] - + [ - "mobile_sam.pt", - "mobileclip_blt.ts", - "yolo11n-grayscale.pt", - "calibration_image_sample_data_20x128x128x3_float32.npy.zip", - ] -) -GITHUB_ASSETS_STEMS = frozenset(k.rpartition(".")[0] for k in GITHUB_ASSETS_NAMES) - - -def is_url(url, check: bool = False) -> bool: - """ - Validate if the given string is a URL and optionally check if the URL exists online. - - Args: - url (str): The string to be validated as a URL. - check (bool, optional): If True, performs an additional check to see if the URL exists online. - - Returns: - (bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online. 
- - Examples: - >>> valid = is_url("https://www.example.com") - >>> valid_and_exists = is_url("https://www.example.com", check=True) - """ - try: - url = str(url) - result = parse.urlparse(url) - assert all([result.scheme, result.netloc]) # check if is url - if check: - with request.urlopen(url) as response: - return response.getcode() == 200 # check if exists online - return True - except Exception: - return False - - -def delete_dsstore(path, files_to_delete=(".DS_Store", "__MACOSX")): - """ - Delete all specified system files in a directory. - - Args: - path (str | Path): The directory path where the files should be deleted. - files_to_delete (tuple): The files to be deleted. - - Examples: - >>> from ultralytics.utils.downloads import delete_dsstore - >>> delete_dsstore("path/to/dir") - - Notes: - ".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They - are hidden system files and can cause issues when transferring files between different operating systems. - """ - for file in files_to_delete: - matches = list(Path(path).rglob(file)) - LOGGER.info(f"Deleting {file} files: {matches}") - for f in matches: - f.unlink() - - -def zip_directory(directory, compress: bool = True, exclude=(".DS_Store", "__MACOSX"), progress: bool = True) -> Path: - """ - Zip the contents of a directory, excluding specified files. - - The resulting zip file is named after the directory and placed alongside it. - - Args: - directory (str | Path): The path to the directory to be zipped. - compress (bool): Whether to compress the files while zipping. - exclude (tuple, optional): A tuple of filename strings to be excluded. - progress (bool, optional): Whether to display a progress bar. - - Returns: - (Path): The path to the resulting zip file. 
- - Examples: - >>> from ultralytics.utils.downloads import zip_directory - >>> file = zip_directory("path/to/dir") - """ - from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile - - delete_dsstore(directory) - directory = Path(directory) - if not directory.is_dir(): - raise FileNotFoundError(f"Directory '{directory}' does not exist.") - - # Zip with progress bar - files_to_zip = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] - zip_file = directory.with_suffix(".zip") - compression = ZIP_DEFLATED if compress else ZIP_STORED - with ZipFile(zip_file, "w", compression) as f: - for file in TQDM(files_to_zip, desc=f"Zipping {directory} to {zip_file}...", unit="file", disable=not progress): - f.write(file, file.relative_to(directory)) - - return zip_file # return path to zip file - - -def unzip_file( - file, - path=None, - exclude=(".DS_Store", "__MACOSX"), - exist_ok: bool = False, - progress: bool = True, -) -> Path: - """ - Unzip a *.zip file to the specified path, excluding specified files. - - If the zipfile does not contain a single top-level directory, the function will create a new - directory with the same name as the zipfile (without the extension) to extract its contents. - If a path is not provided, the function will use the parent directory of the zipfile as the default path. - - Args: - file (str | Path): The path to the zipfile to be extracted. - path (str | Path, optional): The path to extract the zipfile to. - exclude (tuple, optional): A tuple of filename strings to be excluded. - exist_ok (bool, optional): Whether to overwrite existing contents if they exist. - progress (bool, optional): Whether to display a progress bar. - - Returns: - (Path): The path to the directory where the zipfile was extracted. - - Raises: - BadZipFile: If the provided file does not exist or is not a valid zipfile. 
- - Examples: - >>> from ultralytics.utils.downloads import unzip_file - >>> directory = unzip_file("path/to/file.zip") - """ - from zipfile import BadZipFile, ZipFile, is_zipfile - - if not (Path(file).exists() and is_zipfile(file)): - raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.") - if path is None: - path = Path(file).parent # default path - - # Unzip the file contents - with ZipFile(file) as zipObj: - files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)] - top_level_dirs = {Path(f).parts[0] for f in files} - - # Decide to unzip directly or unzip into a directory - unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith("/")) - if unzip_as_dir: - # Zip has 1 top-level directory - extract_path = path # i.e. ../datasets - path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/ - else: - # Zip has multiple files at top level - path = extract_path = Path(path) / Path(file).stem # i.e. extract multiple files to ../datasets/coco8/ - - # Check if destination directory already exists and contains files - if path.exists() and any(path.iterdir()) and not exist_ok: - # If it exists and is not empty, return the path without unzipping - LOGGER.warning(f"Skipping {file} unzip as destination directory {path} is not empty.") - return path - - for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="file", disable=not progress): - # Ensure the file is within the extract_path to avoid path traversal security vulnerability - if ".." in Path(f).parts: - LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.") - continue - zipObj.extract(f, extract_path) - - return path # return unzip dir - - -def check_disk_space( - url: str = "https://ultralytics.com/assets/coco8.zip", - path=Path.cwd(), - sf: float = 1.5, - hard: bool = True, -) -> bool: - """ - Check if there is sufficient disk space to download and store a file. 
- - Args: - url (str, optional): The URL to the file. - path (str | Path, optional): The path or drive to check the available free space on. - sf (float, optional): Safety factor, the multiplier for the required free space. - hard (bool, optional): Whether to throw an error or not on insufficient disk space. - - Returns: - (bool): True if there is sufficient disk space, False otherwise. - """ - import requests # slow import - - try: - r = requests.head(url) # response - assert r.status_code < 400, f"URL error for {url}: {r.status_code} {r.reason}" # check response - except Exception: - return True # requests issue, default to True - - # Check file size - gib = 1 << 30 # bytes per GiB - data = int(r.headers.get("Content-Length", 0)) / gib # file size (GB) - total, used, free = (x / gib for x in shutil.disk_usage(path)) # bytes - - if data * sf < free: - return True # sufficient space - - # Insufficient space - text = ( - f"Insufficient free disk space {free:.1f} GB < {data * sf:.3f} GB required, " - f"Please free {data * sf - free:.1f} GB additional disk space and try again." - ) - if hard: - raise MemoryError(text) - LOGGER.warning(text) - return False - - -def get_google_drive_file_info(link: str) -> Tuple[str, str]: - """ - Retrieve the direct download link and filename for a shareable Google Drive file link. - - Args: - link (str): The shareable link of the Google Drive file. - - Returns: - url (str): Direct download URL for the Google Drive file. - filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None. 
- - Examples: - >>> from ultralytics.utils.downloads import get_google_drive_file_info - >>> link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link" - >>> url, filename = get_google_drive_file_info(link) - """ - import requests # slow import - - file_id = link.split("/d/")[1].split("/view", 1)[0] - drive_url = f"https://drive.google.com/uc?export=download&id={file_id}" - filename = None - - # Start session - with requests.Session() as session: - response = session.get(drive_url, stream=True) - if "quota exceeded" in str(response.content.lower()): - raise ConnectionError( - emojis( - f"❌ Google Drive file download quota exceeded. " - f"Please try again later or download this file manually at {link}." - ) - ) - for k, v in response.cookies.items(): - if k.startswith("download_warning"): - drive_url += f"&confirm={v}" # v is token - if cd := response.headers.get("content-disposition"): - filename = re.findall('filename="(.+)"', cd)[0] - return drive_url, filename - - -def safe_download( - url, - file=None, - dir=None, - unzip: bool = True, - delete: bool = False, - curl: bool = False, - retry: int = 3, - min_bytes: float = 1e0, - exist_ok: bool = False, - progress: bool = True, -): - """ - Download files from a URL with options for retrying, unzipping, and deleting the downloaded file. - - Args: - url (str): The URL of the file to be downloaded. - file (str, optional): The filename of the downloaded file. - If not provided, the file will be saved with the same name as the URL. - dir (str | Path, optional): The directory to save the downloaded file. - If not provided, the file will be saved in the current working directory. - unzip (bool, optional): Whether to unzip the downloaded file. - delete (bool, optional): Whether to delete the downloaded file after unzipping. - curl (bool, optional): Whether to use curl command line tool for downloading. - retry (int, optional): The number of times to retry the download in case of failure. 
- min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered - a successful download. - exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. - progress (bool, optional): Whether to display a progress bar during the download. - - Returns: - (Path | str): The path to the downloaded file or extracted directory. - - Examples: - >>> from ultralytics.utils.downloads import safe_download - >>> link = "https://ultralytics.com/assets/bus.jpg" - >>> path = safe_download(link) - """ - gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link - if gdrive: - url, file = get_google_drive_file_info(url) - - f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename - if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10) - f = Path(url) # filename - elif not f.is_file(): # URL and file do not exist - uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url - "https://github.com/ultralytics/assets/releases/download/v0.0.0/", - "https://ultralytics.com/assets/", # assets alias - ) - desc = f"Downloading {uri} to '{f}'" - f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing - check_disk_space(url, path=f.parent) - curl_installed = shutil.which("curl") - for i in range(retry + 1): - try: - if (curl or i > 0) and curl_installed: # curl download with retry, continue - s = "sS" * (not progress) # silent - r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode - assert r == 0, f"Curl return value {r}" - else: # urllib download - # torch.hub.download_url_to_file(url, f, progress=progress) # do not use as progress tqdm differs - with request.urlopen(url) as response, TQDM( - total=int(response.getheader("Content-Length", 0)), - desc=desc, - disable=not progress, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as 
pbar: - with open(f, "wb") as f_opened: - for data in response: - f_opened.write(data) - pbar.update(len(data)) - - if f.exists(): - if f.stat().st_size > min_bytes: - break # success - f.unlink() # remove partial downloads - except Exception as e: - if i == 0 and not is_online(): - raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e - elif i >= retry: - raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e - LOGGER.warning(f"Download failure, retrying {i + 1}/{retry} {uri}...") - - if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}: - from zipfile import is_zipfile - - unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place - if is_zipfile(f): - unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip - elif f.suffix in {".tar", ".gz"}: - LOGGER.info(f"Unzipping {f} to {unzip_dir}...") - subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True) - if delete: - f.unlink() # remove zip - return unzip_dir - return f - - -def get_github_assets( - repo: str = "ultralytics/assets", - version: str = "latest", - retry: bool = False, -) -> Tuple[str, List[str]]: - """ - Retrieve the specified version's tag and assets from a GitHub repository. - - If the version is not specified, the function fetches the latest release assets. - - Args: - repo (str, optional): The GitHub repository in the format 'owner/repo'. - version (str, optional): The release version to fetch assets from. - retry (bool, optional): Flag to retry the request in case of a failure. - - Returns: - tag (str): The release tag. - assets (List[str]): A list of asset names. - - Examples: - >>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest") - """ - import requests # slow import - - if version != "latest": - version = f"tags/{version}" # i.e. 
tags/v6.2 - url = f"https://api.github.com/repos/{repo}/releases/{version}" - r = requests.get(url) # github api - if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded - r = requests.get(url) # try again - if r.status_code != 200: - LOGGER.warning(f"GitHub assets check failure for {url}: {r.status_code} {r.reason}") - return "", [] - data = r.json() - return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...] - - -def attempt_download_asset(file, repo: str = "ultralytics/assets", release: str = "v8.3.0", **kwargs) -> str: - """ - Attempt to download a file from GitHub release assets if it is not found locally. - - Args: - file (str | Path): The filename or file path to be downloaded. - repo (str, optional): The GitHub repository in the format 'owner/repo'. - release (str, optional): The specific release version to be downloaded. - **kwargs (Any): Additional keyword arguments for the download process. - - Returns: - (str): The path to the downloaded file. - - Examples: - >>> file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest") - """ - from ultralytics.utils import SETTINGS # scoped for circular import - - # YOLOv3/5u updates - file = str(file) - file = checks.check_yolov5u_filename(file) - file = Path(file.strip().replace("'", "")) - if file.exists(): - return str(file) - elif (SETTINGS["weights_dir"] / file).exists(): - return str(SETTINGS["weights_dir"] / file) - else: - # URL specified - name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc. - download_url = f"https://github.com/{repo}/releases/download" - if str(file).startswith(("http:/", "https:/")): # download - url = str(file).replace(":/", "://") # Pathlib turns :// -> :/ - file = url2file(name) # parse authentication https://url.com/file.txt?auth... 
- if Path(file).is_file(): - LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists - else: - safe_download(url=url, file=file, min_bytes=1e5, **kwargs) - - elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES: - safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs) - - else: - tag, assets = get_github_assets(repo, release) - if not assets: - tag, assets = get_github_assets(repo) # latest release - if name in assets: - safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs) - - return str(file) - - -def download( - url, - dir=Path.cwd(), - unzip: bool = True, - delete: bool = False, - curl: bool = False, - threads: int = 1, - retry: int = 3, - exist_ok: bool = False, -): - """ - Download files from specified URLs to a given directory. - - Supports concurrent downloads if multiple threads are specified. - - Args: - url (str | List[str]): The URL or list of URLs of the files to be downloaded. - dir (Path, optional): The directory where the files will be saved. - unzip (bool, optional): Flag to unzip the files after downloading. - delete (bool, optional): Flag to delete the zip files after extraction. - curl (bool, optional): Flag to use curl for downloading. - threads (int, optional): Number of threads to use for concurrent downloads. - retry (int, optional): Number of retries in case of download failure. - exist_ok (bool, optional): Whether to overwrite existing contents during unzipping. 
- - Examples: - >>> download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True) - """ - dir = Path(dir) - dir.mkdir(parents=True, exist_ok=True) # make directory - if threads > 1: - with ThreadPool(threads) as pool: - pool.map( - lambda x: safe_download( - url=x[0], - dir=x[1], - unzip=unzip, - delete=delete, - curl=curl, - retry=retry, - exist_ok=exist_ok, - progress=threads <= 1, - ), - zip(url, repeat(dir)), - ) - pool.close() - pool.join() - else: - for u in [url] if isinstance(url, (str, Path)) else url: - safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/errors.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/errors.py deleted file mode 100644 index ed9fd9e..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/errors.py +++ /dev/null @@ -1,43 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from ultralytics.utils import emojis - - -class HUBModelError(Exception): - """ - Exception raised when a model cannot be found or retrieved from Ultralytics HUB. - - This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO. - The error message is processed to include emojis for better user experience. - - Attributes: - message (str): The error message displayed when the exception is raised. - - Methods: - __init__: Initialize the HUBModelError with a custom message. - - Examples: - >>> try: - ... # Code that might fail to find a model - ... raise HUBModelError("Custom model not found message") - ... except HUBModelError as e: - ... print(e) # Displays the emoji-enhanced error message - """ - - def __init__(self, message: str = "Model not found. Please check model URL and try again."): - """ - Initialize a HUBModelError exception. 
- - This exception is raised when a requested model is not found or cannot be retrieved from Ultralytics HUB. - The message is processed to include emojis for better user experience. - - Args: - message (str, optional): The error message to display when the exception is raised. - - Examples: - >>> try: - ... raise HUBModelError("Custom model error message") - ... except HUBModelError as e: - ... print(e) - """ - super().__init__(emojis(message)) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/export.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/export.py deleted file mode 100644 index 0a34951..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/export.py +++ /dev/null @@ -1,236 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import json -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union - -import torch - -from ultralytics.utils import IS_JETSON, LOGGER - - -def export_onnx( - torch_model: torch.nn.Module, - im: torch.Tensor, - onnx_file: str, - opset: int = 14, - input_names: List[str] = ["images"], - output_names: List[str] = ["output0"], - dynamic: Union[bool, Dict] = False, -) -> None: - """ - Export a PyTorch model to ONNX format. - - Args: - torch_model (torch.nn.Module): The PyTorch model to export. - im (torch.Tensor): Example input tensor for the model. - onnx_file (str): Path to save the exported ONNX file. - opset (int): ONNX opset version to use for export. - input_names (List[str]): List of input tensor names. - output_names (List[str]): List of output tensor names. - dynamic (bool | Dict, optional): Whether to enable dynamic axes. - - Notes: - Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12. 
- """ - torch.onnx.export( - torch_model, - im, - onnx_file, - verbose=False, - opset_version=opset, - do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic or None, - ) - - -def export_engine( - onnx_file: str, - engine_file: Optional[str] = None, - workspace: Optional[int] = None, - half: bool = False, - int8: bool = False, - dynamic: bool = False, - shape: Tuple[int, int, int, int] = (1, 3, 640, 640), - dla: Optional[int] = None, - dataset=None, - metadata: Optional[Dict] = None, - verbose: bool = False, - prefix: str = "", -) -> None: - """ - Export a YOLO model to TensorRT engine format. - - Args: - onnx_file (str): Path to the ONNX file to be converted. - engine_file (str, optional): Path to save the generated TensorRT engine file. - workspace (int, optional): Workspace size in GB for TensorRT. - half (bool, optional): Enable FP16 precision. - int8 (bool, optional): Enable INT8 precision. - dynamic (bool, optional): Enable dynamic input shapes. - shape (Tuple[int, int, int, int], optional): Input shape (batch, channels, height, width). - dla (int, optional): DLA core to use (Jetson devices only). - dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration. - metadata (Dict, optional): Metadata to include in the engine file. - verbose (bool, optional): Enable verbose logging. - prefix (str, optional): Prefix for log messages. - - Raises: - ValueError: If DLA is enabled on non-Jetson devices or required precision is not set. - RuntimeError: If the ONNX file cannot be parsed. - - Notes: - TensorRT version compatibility is handled for workspace size and engine building. - INT8 calibration requires a dataset and generates a calibration cache. - Metadata is serialized and written to the engine file if provided. 
- """ - import tensorrt as trt # noqa - - engine_file = engine_file or Path(onnx_file).with_suffix(".engine") - - logger = trt.Logger(trt.Logger.INFO) - if verbose: - logger.min_severity = trt.Logger.Severity.VERBOSE - - # Engine builder - builder = trt.Builder(logger) - config = builder.create_builder_config() - workspace = int((workspace or 0) * (1 << 30)) - is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10 - if is_trt10 and workspace > 0: - config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace) - elif workspace > 0: # TensorRT versions 7, 8 - config.max_workspace_size = workspace - flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) - network = builder.create_network(flag) - half = builder.platform_has_fast_fp16 and half - int8 = builder.platform_has_fast_int8 and int8 - - # Optionally switch to DLA if enabled - if dla is not None: - if not IS_JETSON: - raise ValueError("DLA is only available on NVIDIA Jetson devices") - LOGGER.info(f"{prefix} enabling DLA on core {dla}...") - if not half and not int8: - raise ValueError( - "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again." 
- ) - config.default_device_type = trt.DeviceType.DLA - config.DLA_core = int(dla) - config.set_flag(trt.BuilderFlag.GPU_FALLBACK) - - # Read ONNX file - parser = trt.OnnxParser(network, logger) - if not parser.parse_from_file(onnx_file): - raise RuntimeError(f"failed to load ONNX file: {onnx_file}") - - # Network inputs - inputs = [network.get_input(i) for i in range(network.num_inputs)] - outputs = [network.get_output(i) for i in range(network.num_outputs)] - for inp in inputs: - LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}') - for out in outputs: - LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}') - - if dynamic: - profile = builder.create_optimization_profile() - min_shape = (1, shape[1], 32, 32) # minimum input shape - max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape - for inp in inputs: - profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape) - config.add_optimization_profile(profile) - if int8: - config.set_calibration_profile(profile) - - LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}") - if int8: - config.set_flag(trt.BuilderFlag.INT8) - config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED - - class EngineCalibrator(trt.IInt8Calibrator): - """ - Custom INT8 calibrator for TensorRT engine optimization. - - This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration - using a dataset. It handles batch generation, caching, and calibration algorithm selection. - - Attributes: - dataset: Dataset for calibration. - data_iter: Iterator over the calibration dataset. - algo (trt.CalibrationAlgoType): Calibration algorithm type. - batch (int): Batch size for calibration. - cache (Path): Path to save the calibration cache. - - Methods: - get_algorithm: Get the calibration algorithm to use. 
- get_batch_size: Get the batch size to use for calibration. - get_batch: Get the next batch to use for calibration. - read_calibration_cache: Use existing cache instead of calibrating again. - write_calibration_cache: Write calibration cache to disk. - """ - - def __init__( - self, - dataset, # ultralytics.data.build.InfiniteDataLoader - cache: str = "", - ) -> None: - """Initialize the INT8 calibrator with dataset and cache path.""" - trt.IInt8Calibrator.__init__(self) - self.dataset = dataset - self.data_iter = iter(dataset) - self.algo = ( - trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2 - if dla is not None - else trt.CalibrationAlgoType.MINMAX_CALIBRATION - ) - self.batch = dataset.batch_size - self.cache = Path(cache) - - def get_algorithm(self) -> trt.CalibrationAlgoType: - """Get the calibration algorithm to use.""" - return self.algo - - def get_batch_size(self) -> int: - """Get the batch size to use for calibration.""" - return self.batch or 1 - - def get_batch(self, names) -> Optional[List[int]]: - """Get the next batch to use for calibration, as a list of device memory pointers.""" - try: - im0s = next(self.data_iter)["img"] / 255.0 - im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s - return [int(im0s.data_ptr())] - except StopIteration: - # Return None to signal to TensorRT there is no calibration data remaining - return None - - def read_calibration_cache(self) -> Optional[bytes]: - """Use existing cache instead of calibrating again, otherwise, implicitly return None.""" - if self.cache.exists() and self.cache.suffix == ".cache": - return self.cache.read_bytes() - - def write_calibration_cache(self, cache: bytes) -> None: - """Write calibration cache to disk.""" - _ = self.cache.write_bytes(cache) - - # Load dataset w/ builder (for batching) and calibrate - config.int8_calibrator = EngineCalibrator( - dataset=dataset, - cache=str(Path(onnx_file).with_suffix(".cache")), - ) - - elif half: - 
config.set_flag(trt.BuilderFlag.FP16) - - # Write file - build = builder.build_serialized_network if is_trt10 else builder.build_engine - with build(network, config) as engine, open(engine_file, "wb") as t: - # Metadata - if metadata is not None: - meta = json.dumps(metadata) - t.write(len(meta).to_bytes(4, byteorder="little", signed=True)) - t.write(meta.encode()) - # Model - t.write(engine if is_trt10 else engine.serialize()) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/files.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/files.py deleted file mode 100644 index 49a692d..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/files.py +++ /dev/null @@ -1,222 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import contextlib -import glob -import os -import shutil -import tempfile -from contextlib import contextmanager -from datetime import datetime -from pathlib import Path -from typing import Union - - -class WorkingDirectory(contextlib.ContextDecorator): - """ - A context manager and decorator for temporarily changing the working directory. - - This class allows for the temporary change of the working directory using a context manager or decorator. - It ensures that the original working directory is restored after the context or decorated function completes. - - Attributes: - dir (Path | str): The new directory to switch to. - cwd (Path): The original current working directory before the switch. - - Methods: - __enter__: Changes the current directory to the specified directory. - __exit__: Restores the original working directory on context exit. 
- - Examples: - Using as a context manager: - >>> with WorkingDirectory('/path/to/new/dir'): - >>> # Perform operations in the new directory - >>> pass - - Using as a decorator: - >>> @WorkingDirectory('/path/to/new/dir') - >>> def some_function(): - >>> # Perform operations in the new directory - >>> pass - """ - - def __init__(self, new_dir: Union[str, Path]): - """Initialize the WorkingDirectory context manager with the target directory.""" - self.dir = new_dir # new dir - self.cwd = Path.cwd().resolve() # current dir - - def __enter__(self): - """Change the current working directory to the specified directory upon entering the context.""" - os.chdir(self.dir) - - def __exit__(self, exc_type, exc_val, exc_tb): # noqa - """Restore the original working directory when exiting the context.""" - os.chdir(self.cwd) - - -@contextmanager -def spaces_in_path(path: Union[str, Path]): - """ - Context manager to handle paths with spaces in their names. - - If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes - the context code block, then copies the file/directory back to its original location. - - Args: - path (str | Path): The original path that may contain spaces. - - Yields: - (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the - original path. 
- - Examples: - >>> with spaces_in_path('/path/with spaces') as new_path: - >>> # Your code here - >>> pass - """ - # If path has spaces, replace them with underscores - if " " in str(path): - string = isinstance(path, str) # input type - path = Path(path) - - # Create a temporary directory and construct the new path - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) / path.name.replace(" ", "_") - - # Copy file/directory - if path.is_dir(): - shutil.copytree(path, tmp_path) - elif path.is_file(): - tmp_path.parent.mkdir(parents=True, exist_ok=True) - shutil.copy2(path, tmp_path) - - try: - # Yield the temporary path - yield str(tmp_path) if string else tmp_path - - finally: - # Copy file/directory back - if tmp_path.is_dir(): - shutil.copytree(tmp_path, path, dirs_exist_ok=True) - elif tmp_path.is_file(): - shutil.copy2(tmp_path, path) # Copy back the file - - else: - # If there are no spaces, just yield the original path - yield path - - -def increment_path(path: Union[str, Path], exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path: - """ - Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc. - - If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to - the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the - number will be appended directly to the end of the path. - - Args: - path (str | Path): Path to increment. - exist_ok (bool, optional): If True, the path will not be incremented and returned as-is. - sep (str, optional): Separator to use between the path and the incrementation number. - mkdir (bool, optional): Create a directory if it does not exist. - - Returns: - (Path): Incremented path. 
- - Examples: - Increment a directory path: - >>> from pathlib import Path - >>> path = Path("runs/exp") - >>> new_path = increment_path(path) - >>> print(new_path) - runs/exp2 - - Increment a file path: - >>> path = Path("runs/exp/results.txt") - >>> new_path = increment_path(path) - >>> print(new_path) - runs/exp/results2.txt - """ - path = Path(path) # os-agnostic - if path.exists() and not exist_ok: - path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "") - - # Method 1 - for n in range(2, 9999): - p = f"{path}{sep}{n}{suffix}" # increment path - if not os.path.exists(p): - break - path = Path(p) - - if mkdir: - path.mkdir(parents=True, exist_ok=True) # make directory - - return path - - -def file_age(path: Union[str, Path] = __file__) -> int: - """Return days since the last modification of the specified file.""" - dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta - return dt.days # + dt.seconds / 86400 # fractional days - - -def file_date(path: Union[str, Path] = __file__) -> str: - """Return the file modification date in 'YYYY-M-D' format.""" - t = datetime.fromtimestamp(Path(path).stat().st_mtime) - return f"{t.year}-{t.month}-{t.day}" - - -def file_size(path: Union[str, Path]) -> float: - """Return the size of a file or directory in megabytes (MB).""" - if isinstance(path, (str, Path)): - mb = 1 << 20 # bytes to MiB (1024 ** 2) - path = Path(path) - if path.is_file(): - return path.stat().st_size / mb - elif path.is_dir(): - return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb - return 0.0 - - -def get_latest_run(search_dir: str = ".") -> str: - """Return the path to the most recent 'last.pt' file in the specified directory for resuming training.""" - last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True) - return max(last_list, key=os.path.getctime) if last_list else "" - - -def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path("."), 
update_names: bool = False): - """ - Update and re-save specified YOLO models in an 'updated_models' subdirectory. - - Args: - model_names (tuple, optional): Model filenames to update. - source_dir (Path, optional): Directory containing models and target subdirectory. - update_names (bool, optional): Update model names from a data YAML. - - Examples: - Update specified YOLO models and save them in 'updated_models' subdirectory: - >>> from ultralytics.utils.files import update_models - >>> model_names = ("yolo11n.pt", "yolov8s.pt") - >>> update_models(model_names, source_dir=Path("/models"), update_names=True) - """ - from ultralytics import YOLO - from ultralytics.nn.autobackend import default_class_names - - target_dir = source_dir / "updated_models" - target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists - - for model_name in model_names: - model_path = source_dir / model_name - print(f"Loading model from {model_path}") - - # Load model - model = YOLO(model_path) - model.half() - if update_names: # update model names from a dataset YAML - model.model.names = default_class_names("coco8.yaml") - - # Define new save path - save_path = target_dir / model_name - - # Save model using model.save() - print(f"Re-saving {model_name} model to {save_path}") - model.save(save_path) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/instance.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/instance.py deleted file mode 100644 index f0c23af..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/instance.py +++ /dev/null @@ -1,504 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from collections import abc -from itertools import repeat -from numbers import Number -from typing import List, Union - -import numpy as np - -from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh - - -def _ntuple(n): - """Create a function that converts input 
to n-tuple by repeating singleton values.""" - - def parse(x): - """Parse input to return n-tuple by repeating singleton values n times.""" - return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n)) - - return parse - - -to_2tuple = _ntuple(2) -to_4tuple = _ntuple(4) - -# `xyxy` means left top and right bottom -# `xywh` means center x, center y and width, height(YOLO format) -# `ltwh` means left top and width, height(COCO format) -_formats = ["xyxy", "xywh", "ltwh"] - -__all__ = ("Bboxes", "Instances") # tuple or list - - -class Bboxes: - """ - A class for handling bounding boxes in multiple formats. - - The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format - conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays. - - Attributes: - bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4). - format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh'). - - Methods: - convert: Convert bounding box format from one type to another. - areas: Calculate the area of bounding boxes. - mul: Multiply bounding box coordinates by scale factor(s). - add: Add offset to bounding box coordinates. - concatenate: Concatenate multiple Bboxes objects. - - Examples: - Create bounding boxes in YOLO format - >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format="xywh") - >>> bboxes.convert("xyxy") - >>> print(bboxes.areas()) - - Notes: - This class does not handle normalization or denormalization of bounding boxes. - """ - - def __init__(self, bboxes: np.ndarray, format: str = "xyxy") -> None: - """ - Initialize the Bboxes class with bounding box data in a specified format. - - Args: - bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,). - format (str): Format of the bounding boxes, one of 'xyxy', 'xywh', or 'ltwh'. 
- """ - assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" - bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes - assert bboxes.ndim == 2 - assert bboxes.shape[1] == 4 - self.bboxes = bboxes - self.format = format - - def convert(self, format: str) -> None: - """ - Convert bounding box format from one type to another. - - Args: - format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'. - """ - assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}" - if self.format == format: - return - elif self.format == "xyxy": - func = xyxy2xywh if format == "xywh" else xyxy2ltwh - elif self.format == "xywh": - func = xywh2xyxy if format == "xyxy" else xywh2ltwh - else: - func = ltwh2xyxy if format == "xyxy" else ltwh2xywh - self.bboxes = func(self.bboxes) - self.format = format - - def areas(self) -> np.ndarray: - """Calculate the area of bounding boxes.""" - return ( - (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy - if self.format == "xyxy" - else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh - ) - - def mul(self, scale: Union[int, tuple, list]) -> None: - """ - Multiply bounding box coordinates by scale factor(s). - - Args: - scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to - all coordinates. - """ - if isinstance(scale, Number): - scale = to_4tuple(scale) - assert isinstance(scale, (tuple, list)) - assert len(scale) == 4 - self.bboxes[:, 0] *= scale[0] - self.bboxes[:, 1] *= scale[1] - self.bboxes[:, 2] *= scale[2] - self.bboxes[:, 3] *= scale[3] - - def add(self, offset: Union[int, tuple, list]) -> None: - """ - Add offset to bounding box coordinates. - - Args: - offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to - all coordinates. 
- """ - if isinstance(offset, Number): - offset = to_4tuple(offset) - assert isinstance(offset, (tuple, list)) - assert len(offset) == 4 - self.bboxes[:, 0] += offset[0] - self.bboxes[:, 1] += offset[1] - self.bboxes[:, 2] += offset[2] - self.bboxes[:, 3] += offset[3] - - def __len__(self) -> int: - """Return the number of bounding boxes.""" - return len(self.bboxes) - - @classmethod - def concatenate(cls, boxes_list: List["Bboxes"], axis: int = 0) -> "Bboxes": - """ - Concatenate a list of Bboxes objects into a single Bboxes object. - - Args: - boxes_list (List[Bboxes]): A list of Bboxes objects to concatenate. - axis (int, optional): The axis along which to concatenate the bounding boxes. - - Returns: - (Bboxes): A new Bboxes object containing the concatenated bounding boxes. - - Notes: - The input should be a list or tuple of Bboxes objects. - """ - assert isinstance(boxes_list, (list, tuple)) - if not boxes_list: - return cls(np.empty(0)) - assert all(isinstance(box, Bboxes) for box in boxes_list) - - if len(boxes_list) == 1: - return boxes_list[0] - return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis)) - - def __getitem__(self, index: Union[int, np.ndarray, slice]) -> "Bboxes": - """ - Retrieve a specific bounding box or a set of bounding boxes using indexing. - - Args: - index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes. - - Returns: - (Bboxes): A new Bboxes object containing the selected bounding boxes. - - Notes: - When using boolean indexing, make sure to provide a boolean array with the same length as the number of - bounding boxes. - """ - if isinstance(index, int): - return Bboxes(self.bboxes[index].reshape(1, -1)) - b = self.bboxes[index] - assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!" - return Bboxes(b) - - -class Instances: - """ - Container for bounding boxes, segments, and keypoints of detected objects in an image. 
- - This class provides a unified interface for handling different types of object annotations including bounding - boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping, - and format conversion. - - Attributes: - _bboxes (Bboxes): Internal object for handling bounding box operations. - keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible). - normalized (bool): Flag indicating whether the bounding box coordinates are normalized. - segments (np.ndarray): Segments array with shape (N, M, 2) after resampling. - - Methods: - convert_bbox: Convert bounding box format. - scale: Scale coordinates by given factors. - denormalize: Convert normalized coordinates to absolute coordinates. - normalize: Convert absolute coordinates to normalized coordinates. - add_padding: Add padding to coordinates. - flipud: Flip coordinates vertically. - fliplr: Flip coordinates horizontally. - clip: Clip coordinates to stay within image boundaries. - remove_zero_area_boxes: Remove boxes with zero area. - update: Update instance variables. - concatenate: Concatenate multiple Instances objects. - - Examples: - Create instances with bounding boxes and segments - >>> instances = Instances( - ... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]), - ... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])], - ... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]), - ... ) - """ - - def __init__( - self, - bboxes: np.ndarray, - segments: np.ndarray = None, - keypoints: np.ndarray = None, - bbox_format: str = "xywh", - normalized: bool = True, - ) -> None: - """ - Initialize the Instances object with bounding boxes, segments, and keypoints. - - Args: - bboxes (np.ndarray): Bounding boxes with shape (N, 4). - segments (np.ndarray, optional): Segmentation masks. - keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible). 
- bbox_format (str): Format of bboxes. - normalized (bool): Whether the coordinates are normalized. - """ - self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format) - self.keypoints = keypoints - self.normalized = normalized - self.segments = segments - - def convert_bbox(self, format: str) -> None: - """ - Convert bounding box format. - - Args: - format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'. - """ - self._bboxes.convert(format=format) - - @property - def bbox_areas(self) -> np.ndarray: - """Calculate the area of bounding boxes.""" - return self._bboxes.areas() - - def scale(self, scale_w: float, scale_h: float, bbox_only: bool = False): - """ - Scale coordinates by given factors. - - Args: - scale_w (float): Scale factor for width. - scale_h (float): Scale factor for height. - bbox_only (bool, optional): Whether to scale only bounding boxes. - """ - self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h)) - if bbox_only: - return - self.segments[..., 0] *= scale_w - self.segments[..., 1] *= scale_h - if self.keypoints is not None: - self.keypoints[..., 0] *= scale_w - self.keypoints[..., 1] *= scale_h - - def denormalize(self, w: int, h: int) -> None: - """ - Convert normalized coordinates to absolute coordinates. - - Args: - w (int): Image width. - h (int): Image height. - """ - if not self.normalized: - return - self._bboxes.mul(scale=(w, h, w, h)) - self.segments[..., 0] *= w - self.segments[..., 1] *= h - if self.keypoints is not None: - self.keypoints[..., 0] *= w - self.keypoints[..., 1] *= h - self.normalized = False - - def normalize(self, w: int, h: int) -> None: - """ - Convert absolute coordinates to normalized coordinates. - - Args: - w (int): Image width. - h (int): Image height. 
- """ - if self.normalized: - return - self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h)) - self.segments[..., 0] /= w - self.segments[..., 1] /= h - if self.keypoints is not None: - self.keypoints[..., 0] /= w - self.keypoints[..., 1] /= h - self.normalized = True - - def add_padding(self, padw: int, padh: int) -> None: - """ - Add padding to coordinates. - - Args: - padw (int): Padding width. - padh (int): Padding height. - """ - assert not self.normalized, "you should add padding with absolute coordinates." - self._bboxes.add(offset=(padw, padh, padw, padh)) - self.segments[..., 0] += padw - self.segments[..., 1] += padh - if self.keypoints is not None: - self.keypoints[..., 0] += padw - self.keypoints[..., 1] += padh - - def __getitem__(self, index: Union[int, np.ndarray, slice]) -> "Instances": - """ - Retrieve a specific instance or a set of instances using indexing. - - Args: - index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances. - - Returns: - (Instances): A new Instances object containing the selected boxes, segments, and keypoints if present. - - Notes: - When using boolean indexing, make sure to provide a boolean array with the same length as the number of - instances. - """ - segments = self.segments[index] if len(self.segments) else self.segments - keypoints = self.keypoints[index] if self.keypoints is not None else None - bboxes = self.bboxes[index] - bbox_format = self._bboxes.format - return Instances( - bboxes=bboxes, - segments=segments, - keypoints=keypoints, - bbox_format=bbox_format, - normalized=self.normalized, - ) - - def flipud(self, h: int) -> None: - """ - Flip coordinates vertically. - - Args: - h (int): Image height. 
- """ - if self._bboxes.format == "xyxy": - y1 = self.bboxes[:, 1].copy() - y2 = self.bboxes[:, 3].copy() - self.bboxes[:, 1] = h - y2 - self.bboxes[:, 3] = h - y1 - else: - self.bboxes[:, 1] = h - self.bboxes[:, 1] - self.segments[..., 1] = h - self.segments[..., 1] - if self.keypoints is not None: - self.keypoints[..., 1] = h - self.keypoints[..., 1] - - def fliplr(self, w: int) -> None: - """ - Flip coordinates horizontally. - - Args: - w (int): Image width. - """ - if self._bboxes.format == "xyxy": - x1 = self.bboxes[:, 0].copy() - x2 = self.bboxes[:, 2].copy() - self.bboxes[:, 0] = w - x2 - self.bboxes[:, 2] = w - x1 - else: - self.bboxes[:, 0] = w - self.bboxes[:, 0] - self.segments[..., 0] = w - self.segments[..., 0] - if self.keypoints is not None: - self.keypoints[..., 0] = w - self.keypoints[..., 0] - - def clip(self, w: int, h: int) -> None: - """ - Clip coordinates to stay within image boundaries. - - Args: - w (int): Image width. - h (int): Image height. - """ - ori_format = self._bboxes.format - self.convert_bbox(format="xyxy") - self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w) - self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h) - if ori_format != "xyxy": - self.convert_bbox(format=ori_format) - self.segments[..., 0] = self.segments[..., 0].clip(0, w) - self.segments[..., 1] = self.segments[..., 1].clip(0, h) - if self.keypoints is not None: - # Set out of bounds visibility to zero - self.keypoints[..., 2][ - (self.keypoints[..., 0] < 0) - | (self.keypoints[..., 0] > w) - | (self.keypoints[..., 1] < 0) - | (self.keypoints[..., 1] > h) - ] = 0.0 - self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w) - self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h) - - def remove_zero_area_boxes(self) -> np.ndarray: - """ - Remove zero-area boxes, i.e. after clipping some boxes may have zero width or height. - - Returns: - (np.ndarray): Boolean array indicating which boxes were kept. 
- """ - good = self.bbox_areas > 0 - if not all(good): - self._bboxes = self._bboxes[good] - if len(self.segments): - self.segments = self.segments[good] - if self.keypoints is not None: - self.keypoints = self.keypoints[good] - return good - - def update(self, bboxes: np.ndarray, segments: np.ndarray = None, keypoints: np.ndarray = None): - """ - Update instance variables. - - Args: - bboxes (np.ndarray): New bounding boxes. - segments (np.ndarray, optional): New segments. - keypoints (np.ndarray, optional): New keypoints. - """ - self._bboxes = Bboxes(bboxes, format=self._bboxes.format) - if segments is not None: - self.segments = segments - if keypoints is not None: - self.keypoints = keypoints - - def __len__(self) -> int: - """Return the number of instances.""" - return len(self.bboxes) - - @classmethod - def concatenate(cls, instances_list: List["Instances"], axis=0) -> "Instances": - """ - Concatenate a list of Instances objects into a single Instances object. - - Args: - instances_list (List[Instances]): A list of Instances objects to concatenate. - axis (int, optional): The axis along which the arrays will be concatenated. - - Returns: - (Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints - if present. - - Notes: - The `Instances` objects in the list should have the same properties, such as the format of the bounding - boxes, whether keypoints are present, and if the coordinates are normalized. 
- """ - assert isinstance(instances_list, (list, tuple)) - if not instances_list: - return cls(np.empty(0)) - assert all(isinstance(instance, Instances) for instance in instances_list) - - if len(instances_list) == 1: - return instances_list[0] - - use_keypoint = instances_list[0].keypoints is not None - bbox_format = instances_list[0]._bboxes.format - normalized = instances_list[0].normalized - - cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis) - seg_len = [b.segments.shape[1] for b in instances_list] - if len(frozenset(seg_len)) > 1: # resample segments if there's different length - max_len = max(seg_len) - cat_segments = np.concatenate( - [ - resample_segments(list(b.segments), max_len) - if len(b.segments) - else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments - for b in instances_list - ], - axis=axis, - ) - else: - cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis) - cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None - return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized) - - @property - def bboxes(self) -> np.ndarray: - """Return bounding boxes.""" - return self._bboxes.bboxes diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/loss.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/loss.py deleted file mode 100644 index 2be34e2..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/loss.py +++ /dev/null @@ -1,850 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -from typing import Any, Dict, List, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from ultralytics.utils.metrics import OKS_SIGMA -from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh -from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors -from 
ultralytics.utils.torch_utils import autocast - -from .metrics import bbox_iou, probiou -from .tal import bbox2dist - - -class VarifocalLoss(nn.Module): - """ - Varifocal loss by Zhang et al. - - Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on - hard-to-classify examples and balancing positive/negative samples. - - Attributes: - gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples. - alpha (float): The balancing factor used to address class imbalance. - - References: - https://arxiv.org/abs/2008.13367 - """ - - def __init__(self, gamma: float = 2.0, alpha: float = 0.75): - """Initialize the VarifocalLoss class with focusing and balancing parameters.""" - super().__init__() - self.gamma = gamma - self.alpha = alpha - - def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor: - """Compute varifocal loss between predictions and ground truth.""" - weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label - with autocast(enabled=False): - loss = ( - (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight) - .mean(1) - .sum() - ) - return loss - - -class FocalLoss(nn.Module): - """ - Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5). - - Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing - on hard negatives during training. - - Attributes: - gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples. - alpha (torch.Tensor): The balancing factor used to address class imbalance. 
- """ - - def __init__(self, gamma: float = 1.5, alpha: float = 0.25): - """Initialize FocalLoss class with focusing and balancing parameters.""" - super().__init__() - self.gamma = gamma - self.alpha = torch.tensor(alpha) - - def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor: - """Calculate focal loss with modulating factors for class imbalance.""" - loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none") - # p_t = torch.exp(-loss) - # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability - - # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py - pred_prob = pred.sigmoid() # prob from logits - p_t = label * pred_prob + (1 - label) * (1 - pred_prob) - modulating_factor = (1.0 - p_t) ** self.gamma - loss *= modulating_factor - if (self.alpha > 0).any(): - self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype) - alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha) - loss *= alpha_factor - return loss.mean(1).sum() - - -class DFLoss(nn.Module): - """Criterion class for computing Distribution Focal Loss (DFL).""" - - def __init__(self, reg_max: int = 16) -> None: - """Initialize the DFL module with regularization maximum.""" - super().__init__() - self.reg_max = reg_max - - def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391.""" - target = target.clamp_(0, self.reg_max - 1 - 0.01) - tl = target.long() # target left - tr = tl + 1 # target right - wl = tr - target # weight left - wr = 1 - wl # weight right - return ( - F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl - + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr - ).mean(-1, keepdim=True) - - -class BboxLoss(nn.Module): - """Criterion class for computing training losses for 
bounding boxes.""" - - def __init__(self, reg_max: int = 16): - """Initialize the BboxLoss module with regularization maximum and DFL settings.""" - super().__init__() - self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None - - def forward( - self, - pred_dist: torch.Tensor, - pred_bboxes: torch.Tensor, - anchor_points: torch.Tensor, - target_bboxes: torch.Tensor, - target_scores: torch.Tensor, - target_scores_sum: torch.Tensor, - fg_mask: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute IoU and DFL losses for bounding boxes.""" - weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) - iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True) - loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum - - # DFL loss - if self.dfl_loss: - target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1) - loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight - loss_dfl = loss_dfl.sum() / target_scores_sum - else: - loss_dfl = torch.tensor(0.0).to(pred_dist.device) - - return loss_iou, loss_dfl - - -class RotatedBboxLoss(BboxLoss): - """Criterion class for computing training losses for rotated bounding boxes.""" - - def __init__(self, reg_max: int): - """Initialize the RotatedBboxLoss module with regularization maximum and DFL settings.""" - super().__init__(reg_max) - - def forward( - self, - pred_dist: torch.Tensor, - pred_bboxes: torch.Tensor, - anchor_points: torch.Tensor, - target_bboxes: torch.Tensor, - target_scores: torch.Tensor, - target_scores_sum: torch.Tensor, - fg_mask: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute IoU and DFL losses for rotated bounding boxes.""" - weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1) - iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask]) - loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum - - # DFL loss - if self.dfl_loss: - target_ltrb = 
class KeypointLoss(nn.Module):
    """Criterion class for computing keypoint losses."""

    def __init__(self, sigmas: torch.Tensor) -> None:
        """Initialize the KeypointLoss class with keypoint sigmas."""
        super().__init__()
        self.sigmas = sigmas

    def forward(
        self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        """Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
        dx = pred_kpts[..., 0] - gt_kpts[..., 0]
        dy = pred_kpts[..., 1] - gt_kpts[..., 1]
        sq_dist = dx.pow(2) + dy.pow(2)
        # Scale each instance by total keypoints / visible keypoints so sparse labels
        # contribute comparably to fully-labeled ones.
        visible = torch.sum(kpt_mask != 0, dim=1) + 1e-9
        scale = kpt_mask.shape[1] / visible
        # OKS-style exponent, per the cocoeval formulation (area-normalized).
        exponent = sq_dist / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2)
        return (scale.view(-1, 1) * ((1 - torch.exp(-exponent)) * kpt_mask)).mean()
class v8DetectionLoss:
    """Criterion class for computing training losses for YOLOv8 object detection."""

    def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
        """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # model strides
        self.nc = m.nc  # number of classes
        self.no = m.nc + m.reg_max * 4  # channels per anchor: classes + 4 sides of DFL bins
        self.reg_max = m.reg_max
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max).to(device)
        # Projection vector [0, 1, ..., reg_max-1] used to take the expectation of the DFL distribution.
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

    def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
        """Preprocess targets by converting to tensor format and scaling coordinates."""
        nl, ne = targets.shape
        if nl == 0:
            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            # Pad every image to the max label count in the batch.
            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    out[j, :n] = targets[matches, 1:]
            # Scale normalized xywh to pixels, then convert to xyxy.
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
        """Decode predicted object bounding box coordinates from anchor points and distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Expectation over the per-side bin distribution gives the side distance.
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds[1] if isinstance(preds, tuple) else preds
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # (b, channels, anchors) -> (b, anchors, channels)
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)  # padded rows are all-zero -> False

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)  # floor at 1 to avoid divide-by-zero

        # Cls loss
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor  # back to feature-map units for the bbox loss
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)
class v8SegmentationLoss(v8DetectionLoss):
    """Criterion class for computing training losses for YOLOv8 segmentation."""

    def __init__(self, model):  # model must be de-paralleled
        """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
        super().__init__(model)
        self.overlap = model.args.overlap_mask

    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate and return the combined loss for detection and segmentation."""
        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets — a RuntimeError here almost always means a detect-format dataset was supplied.
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample GT masks to proto resolution
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

    @staticmethod
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the instance segmentation loss for a single image.

        Args:
            gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
            pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
            proto (torch.Tensor): Prototype masks of shape (32, H, W).
            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
            area (torch.Tensor): Area of each ground truth bounding box of shape (N,).

        Returns:
            (torch.Tensor): The calculated mask loss for a single image.

        Notes:
            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
            predicted masks from the prototype masks and predicted mask coefficients.
        """
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        # Crop the loss to the GT box and normalize by box area so large and small objects weigh equally.
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

    def calculate_segmentation_loss(
        self,
        fg_mask: torch.Tensor,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        target_bboxes: torch.Tensor,
        batch_idx: torch.Tensor,
        proto: torch.Tensor,
        pred_masks: torch.Tensor,
        imgsz: torch.Tensor,
        overlap: bool,
    ) -> torch.Tensor:
        """
        Calculate the loss for instance segmentation.

        Args:
            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
            masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
            overlap (bool): Whether the masks in `masks` tensor overlap.

        Returns:
            (torch.Tensor): The calculated loss for instance segmentation.

        Notes:
            The batch loss can be computed for improved speed at higher memory usage.
            For example, pred_mask can be computed as follows:
            pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        """
        _, _, mask_h, mask_w = proto.shape
        loss = 0

        # Normalize to 0-1
        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

        # Areas of target bboxes
        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

        # Normalize to mask size
        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
            if fg_mask_i.any():
                mask_idx = target_gt_idx_i[fg_mask_i]
                if overlap:
                    # Overlapping masks are index-encoded: instance k is stored as value k+1.
                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                    gt_mask = gt_mask.float()
                else:
                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

                loss += self.single_mask_loss(
                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
                )

            # WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
            else:
                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        return loss / fg_mask.sum()
class v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses for YOLOv8 pose estimation."""

    def __init__(self, model):  # model must be de-paralleled
        """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        is_pose = self.kpt_shape == [17, 3]  # COCO 17-keypoint layout gets the official OKS sigmas
        nkpt = self.kpt_shape[0]  # number of keypoints
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate the total loss and detach it for pose estimation."""
        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            # Clone before scaling: batch["keypoints"] must not be mutated for other consumers.
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

    @staticmethod
    def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
        """Decode predicted keypoints to image coordinates."""
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    def calculate_keypoints_loss(
        self,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        keypoints: torch.Tensor,
        batch_idx: torch.Tensor,
        stride_tensor: torch.Tensor,
        target_bboxes: torch.Tensor,
        pred_kpts: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
        a binary classification loss that classifies whether a keypoint is present or not.

        Args:
            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            kpts_loss (torch.Tensor): The keypoints loss.
            kpts_obj_loss (torch.Tensor): The keypoints object loss.
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Find the maximum number of keypoints in a single image
        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

        # Create a tensor to hold batched keypoints
        batched_keypoints = torch.zeros(
            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride
        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            # Visibility flag (3rd channel) marks labeled keypoints; 2-channel kpts are all "visible".
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss, kpts_obj_loss
class v8ClassificationLoss:
    """Criterion class for computing training losses for classification."""

    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the classification loss between predictions and true labels."""
        # Some model heads return (features, logits); keep only the logits.
        logits = preds[1] if isinstance(preds, (list, tuple)) else preds
        loss = F.cross_entropy(logits, batch["cls"], reduction="mean")
        return loss, loss.detach()
class v8OBBLoss(v8DetectionLoss):
    """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

    def __init__(self, model):
        """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
        super().__init__(model)
        # Replace the parent's axis-aligned assigner/bbox loss with rotated-box variants.
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

    def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
        """Preprocess targets for oriented bounding box detection."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    bboxes = targets[matches, 2:]
                    # Scale only xywh; the rotation angle (5th component) stays untouched.
                    bboxes[..., :4].mul_(scale_tensor)
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate and return the loss for oriented bounding box detection."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets — a RuntimeError here almost always means a non-OBB dataset was supplied.
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
                "This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xyxy, (b, h*w, 4)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            # Keep pred_angle in the graph to avoid DDP 'unused gradient' errors when no foreground exists.
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(
        self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
    ) -> torch.Tensor:
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            # Expectation over the DFL bin distribution yields per-side distances.
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
- """ - if self.use_dfl: - b, a, c = pred_dist.shape # batch, anchors, channels - pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype)) - return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1) - - -class E2EDetectLoss: - """Criterion class for computing training losses for end-to-end detection.""" - - def __init__(self, model): - """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model.""" - self.one2many = v8DetectionLoss(model, tal_topk=10) - self.one2one = v8DetectionLoss(model, tal_topk=1) - - def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate the sum of the loss for box, cls and dfl multiplied by batch size.""" - preds = preds[1] if isinstance(preds, tuple) else preds - one2many = preds["one2many"] - loss_one2many = self.one2many(one2many, batch) - one2one = preds["one2one"] - loss_one2one = self.one2one(one2one, batch) - return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1] - - -class TVPDetectLoss: - """Criterion class for computing training losses for text-visual prompt detection.""" - - def __init__(self, model): - """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model.""" - self.vp_criterion = v8DetectionLoss(model) - # NOTE: store following info as it's changeable in __call__ - self.ori_nc = self.vp_criterion.nc - self.ori_no = self.vp_criterion.no - self.ori_reg_max = self.vp_criterion.reg_max - - def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculate the loss for text-visual prompt detection.""" - feats = preds[1] if isinstance(preds, tuple) else preds - assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it - - if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]: - loss = torch.zeros(3, device=self.vp_criterion.device, 
class TVPDetectLoss:
    """Criterion class for computing training losses for text-visual prompt detection."""

    def __init__(self, model):
        """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
        self.vp_criterion = v8DetectionLoss(model)
        # NOTE: store following info as it's changeable in __call__
        self.ori_nc = self.vp_criterion.nc
        self.ori_no = self.vp_criterion.no
        self.ori_reg_max = self.vp_criterion.reg_max

    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate the loss for text-visual prompt detection."""
        feats = preds[1] if isinstance(preds, tuple) else preds
        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

        # No visual-prompt channels present: return a zero (but grad-tracking) loss.
        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
            loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
            return loss, loss.detach()

        vp_feats = self._get_vp_features(feats)
        vp_loss = self.vp_criterion(vp_feats, batch)
        box_loss = vp_loss[0][1]
        return box_loss, vp_loss[1]

    def _get_vp_features(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
        """Extract visual-prompt features from the model output."""
        # Channels beyond box distribution + original classes belong to visual prompts.
        vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc

        # NOTE(review): this mutates the shared criterion's class count for the VP pass.
        self.vp_criterion.nc = vnc
        self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
        self.vp_criterion.assigner.num_classes = vnc

        # Keep box channels, drop original class logits, keep visual-prompt logits.
        return [
            torch.cat((box, cls_vp), dim=1)
            for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
        ]


class TVPSegmentLoss(TVPDetectLoss):
    """Criterion class for computing training losses for text-visual prompt segmentation."""

    def __init__(self, model):
        """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
        super().__init__(model)
        self.vp_criterion = v8SegmentationLoss(model)

    def __call__(self, preds: Any, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate the loss for text-visual prompt segmentation."""
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

        # No visual-prompt channels present: return a zero (but grad-tracking) loss.
        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
            return loss, loss.detach()

        vp_feats = self._get_vp_features(feats)
        vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
        cls_loss = vp_loss[0][2]
        return cls_loss, vp_loss[1]
🚀 AGPL-3.0 License - https://ultralytics.com/license -"""Model validation metrics.""" - -import math -import warnings -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple, Union - -import numpy as np -import torch - -from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings - -OKS_SIGMA = ( - np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89]) - / 10.0 -) - - -def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray: - """ - Calculate the intersection over box2 area given box1 and box2. - - Args: - box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format. - box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format. - iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area. - eps (float, optional): A small value to avoid division by zero. - - Returns: - (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area. - """ - # Get the coordinates of bounding boxes - b1_x1, b1_y1, b1_x2, b1_y2 = box1.T - b2_x1, b2_y1, b2_x2, b2_y2 = box2.T - - # Intersection area - inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * ( - np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1) - ).clip(0) - - # Box2 area - area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1) - if iou: - box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1) - area = area + box1_area[:, None] - inter_area - - # Intersection over box2 area - return inter_area / (area + eps) - - -def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - """ - Calculate intersection-over-union (IoU) of boxes. - - Args: - box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format. 
- box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format. - eps (float, optional): A small value to avoid division by zero. - - Returns: - (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2. - - References: - https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py - """ - # NOTE: Need .float() to get accurate iou values - # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) - (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2) - inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2) - - # IoU = inter / (area1 + area2 - inter) - return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps) - - -def bbox_iou( - box1: torch.Tensor, - box2: torch.Tensor, - xywh: bool = True, - GIoU: bool = False, - DIoU: bool = False, - CIoU: bool = False, - eps: float = 1e-7, -) -> torch.Tensor: - """ - Calculate the Intersection over Union (IoU) between bounding boxes. - - This function supports various shapes for `box1` and `box2` as long as the last dimension is 4. - For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4). - Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`, - or (x1, y1, x2, y2) if `xywh=False`. - - Args: - box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4. - box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4. - xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in - (x1, y1, x2, y2) format. - GIoU (bool, optional): If True, calculate Generalized IoU. - DIoU (bool, optional): If True, calculate Distance IoU. - CIoU (bool, optional): If True, calculate Complete IoU. - eps (float, optional): A small value to avoid division by zero. 
- - Returns: - (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags. - """ - # Get the coordinates of bounding boxes - if xywh: # transform from xywh to xyxy - (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1) - w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2 - b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_ - b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_ - else: # x1, y1, x2, y2 = box1 - b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1) - b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1) - w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps - w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps - - # Intersection area - inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * ( - b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1) - ).clamp_(0) - - # Union Area - union = w1 * h1 + w2 * h2 - inter + eps - - # IoU - iou = inter / union - if CIoU or DIoU or GIoU: - cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width - ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height - if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 - c2 = cw.pow(2) + ch.pow(2) + eps # convex diagonal squared - rho2 = ( - (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2) - ) / 4 # center dist**2 - if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 - v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2) - with torch.no_grad(): - alpha = v / (v - iou + (1 + eps)) - return iou - (rho2 / c2 + v * alpha) # CIoU - return iou - rho2 / c2 # DIoU - c_area = cw * ch + eps # convex area - return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf - return iou # IoU - - -def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: - """ - Calculate masks IoU. 
- - Args: - mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the - product of image width and height. - mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the - product of image width and height. - eps (float, optional): A small value to avoid division by zero. - - Returns: - (torch.Tensor): A tensor of shape (N, M) representing masks IoU. - """ - intersection = torch.matmul(mask1, mask2.T).clamp_(0) - union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection - return intersection / (union + eps) - - -def kpt_iou( - kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: List[float], eps: float = 1e-7 -) -> torch.Tensor: - """ - Calculate Object Keypoint Similarity (OKS). - - Args: - kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints. - kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints. - area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth. - sigma (list): A list containing 17 values representing keypoint scales. - eps (float, optional): A small value to avoid division by zero. - - Returns: - (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities. - """ - d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17) - sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, ) - kpt_mask = kpt1[..., 2] != 0 # (N, 17) - e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval - # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula - return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps) - - -def _get_covariance_matrix(boxes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Generate covariance matrix from oriented bounding boxes. 
- - Args: - boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format. - - Returns: - (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes. - """ - # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here. - gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1) - a, b, c = gbbs.split(1, dim=-1) - cos = c.cos() - sin = c.sin() - cos2 = cos.pow(2) - sin2 = sin.pow(2) - return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin - - -def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor: - """ - Calculate probabilistic IoU between oriented bounding boxes. - - Args: - obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr. - obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr. - CIoU (bool, optional): If True, calculate CIoU. - eps (float, optional): Small value to avoid division by zero. - - Returns: - (torch.Tensor): OBB similarities, shape (N,). - - Notes: - OBB format: [center_x, center_y, width, height, rotation_angle]. 
- - References: - https://arxiv.org/pdf/2106.06072v1.pdf - """ - x1, y1 = obb1[..., :2].split(1, dim=-1) - x2, y2 = obb2[..., :2].split(1, dim=-1) - a1, b1, c1 = _get_covariance_matrix(obb1) - a2, b2, c2 = _get_covariance_matrix(obb2) - - t1 = ( - ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) - ) * 0.25 - t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5 - t3 = ( - ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)) - / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps) - + eps - ).log() * 0.5 - bd = (t1 + t2 + t3).clamp(eps, 100.0) - hd = (1.0 - (-bd).exp() + eps).sqrt() - iou = 1 - hd - if CIoU: # only include the wh aspect ratio part - w1, h1 = obb1[..., 2:4].split(1, dim=-1) - w2, h2 = obb2[..., 2:4].split(1, dim=-1) - v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2) - with torch.no_grad(): - alpha = v / (v - iou + (1 + eps)) - return iou - v * alpha # CIoU - return iou - - -def batch_probiou( - obb1: Union[torch.Tensor, np.ndarray], obb2: Union[torch.Tensor, np.ndarray], eps: float = 1e-7 -) -> torch.Tensor: - """ - Calculate the probabilistic IoU between oriented bounding boxes. - - Args: - obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format. - obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format. - eps (float, optional): A small value to avoid division by zero. - - Returns: - (torch.Tensor): A tensor of shape (N, M) representing obb similarities. 
- - References: - https://arxiv.org/pdf/2106.06072v1.pdf - """ - obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1 - obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2 - - x1, y1 = obb1[..., :2].split(1, dim=-1) - x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1)) - a1, b1, c1 = _get_covariance_matrix(obb1) - a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2)) - - t1 = ( - ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps) - ) * 0.25 - t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5 - t3 = ( - ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2)) - / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps) - + eps - ).log() * 0.5 - bd = (t1 + t2 + t3).clamp(eps, 100.0) - hd = (1.0 - (-bd).exp() + eps).sqrt() - return 1 - hd - - -def smooth_bce(eps: float = 0.1) -> Tuple[float, float]: - """ - Compute smoothed positive and negative Binary Cross-Entropy targets. - - Args: - eps (float, optional): The epsilon value for label smoothing. - - Returns: - pos (float): Positive label smoothing BCE target. - neg (float): Negative label smoothing BCE target. - - References: - https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 - """ - return 1.0 - 0.5 * eps, 0.5 * eps - - -class ConfusionMatrix(DataExportMixin): - """ - A class for calculating and updating a confusion matrix for object detection and classification tasks. - - Attributes: - task (str): The type of task, either 'detect' or 'classify'. - matrix (np.ndarray): The confusion matrix, with dimensions depending on the task. - nc (int): The number of category. - names (List[str]): The names of the classes, used as labels on the plot. - matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN. 
- """ - - def __init__(self, names: Dict[int, str] = [], task: str = "detect", save_matches: bool = False): - """ - Initialize a ConfusionMatrix instance. - - Args: - names (Dict[int, str], optional): Names of classes, used as labels on the plot. - task (str, optional): Type of task, either 'detect' or 'classify'. - save_matches (bool, optional): Save the indices of GTs, TPs, FPs, FNs for visualization. - """ - self.task = task - self.nc = len(names) # number of classes - self.matrix = np.zeros((self.nc, self.nc)) if self.task == "classify" else np.zeros((self.nc + 1, self.nc + 1)) - self.names = names # name of classes - self.matches = {} if save_matches else None - - def _append_matches(self, mtype: str, batch: Dict[str, Any], idx: int) -> None: - """ - Append the matches to TP, FP, FN or GT list for the last batch. - - This method updates the matches dictionary by appending specific batch data - to the appropriate match type (True Positive, False Positive, or False Negative). - - Args: - mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT'). - batch (Dict[str, Any]): Batch data containing detection results with keys - like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'. - idx (int): Index of the specific detection to append from the batch. - - Note: - For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0, - it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing. - """ - if self.matches is None: - return - for k, v in batch.items(): - if k in {"bboxes", "cls", "conf", "keypoints"}: - self.matches[mtype][k] += v[[idx]] - elif k == "masks": - # NOTE: masks.max() > 1.0 means overlap_mask=True with (1, H, W) shape - self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]] - - def process_cls_preds(self, preds: List[torch.Tensor], targets: List[torch.Tensor]) -> None: - """ - Update confusion matrix for classification task. - - Args: - preds (List[N, min(nc,5)]): Predicted class labels. 
- targets (List[N, 1]): Ground truth class labels. - """ - preds, targets = torch.cat(preds)[:, 0], torch.cat(targets) - for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()): - self.matrix[p][t] += 1 - - def process_batch( - self, - detections: Dict[str, torch.Tensor], - batch: Dict[str, Any], - conf: float = 0.25, - iou_thres: float = 0.45, - ) -> None: - """ - Update confusion matrix for object detection task. - - Args: - detections (Dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information. - Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be - Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle. - batch (Dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4]| Array[M, 5]) and - 'cls' (Array[M]) keys, where M is the number of ground truth objects. - conf (float, optional): Confidence threshold for detections. - iou_thres (float, optional): IoU threshold for matching detections to ground truth. 
- """ - gt_cls, gt_bboxes = batch["cls"], batch["bboxes"] - if self.matches is not None: # only if visualization is enabled - self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}} - for i in range(len(gt_cls)): - self._append_matches("GT", batch, i) # store GT - is_obb = gt_bboxes.shape[1] == 5 # check if boxes contains angle for OBB - conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf # apply 0.25 if default val conf is passed - no_pred = len(detections["cls"]) == 0 - if gt_cls.shape[0] == 0: # Check if labels is empty - if not no_pred: - detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()} - detection_classes = detections["cls"].int().tolist() - for i, dc in enumerate(detection_classes): - self.matrix[dc, self.nc] += 1 # FP - self._append_matches("FP", detections, i) - return - if no_pred: - gt_classes = gt_cls.int().tolist() - for i, gc in enumerate(gt_classes): - self.matrix[self.nc, gc] += 1 # FN - self._append_matches("FN", batch, i) - return - - detections = {k: detections[k][detections["conf"] > conf] for k in detections.keys()} - gt_classes = gt_cls.int().tolist() - detection_classes = detections["cls"].int().tolist() - bboxes = detections["bboxes"] - iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes) - - x = torch.where(iou > iou_thres) - if x[0].shape[0]: - matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() - if x[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] - else: - matches = np.zeros((0, 3)) - - n = matches.shape[0] > 0 - m0, m1, _ = matches.transpose().astype(int) - for i, gc in enumerate(gt_classes): - j = m0 == i - if n and sum(j) == 1: - dc = detection_classes[m1[j].item()] - self.matrix[dc, gc] += 1 # TP if class 
is correct else both an FP and an FN - if dc == gc: - self._append_matches("TP", detections, m1[j].item()) - else: - self._append_matches("FP", detections, m1[j].item()) - self._append_matches("FN", batch, i) - else: - self.matrix[self.nc, gc] += 1 # FN - self._append_matches("FN", batch, i) - - for i, dc in enumerate(detection_classes): - if not any(m1 == i): - self.matrix[dc, self.nc] += 1 # FP - self._append_matches("FP", detections, i) - - def matrix(self): - """Return the confusion matrix.""" - return self.matrix - - def tp_fp(self) -> Tuple[np.ndarray, np.ndarray]: - """ - Return true positives and false positives. - - Returns: - tp (np.ndarray): True positives. - fp (np.ndarray): False positives. - """ - tp = self.matrix.diagonal() # true positives - fp = self.matrix.sum(1) - tp # false positives - # fn = self.matrix.sum(0) - tp # false negatives (missed detections) - return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1]) # remove background class if task=detect - - def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None: - """ - Plot grid of GT, TP, FP, FN for each image. - - Args: - img (torch.Tensor): Image to plot onto. - im_file (str): Image filename to save visualizations. - save_dir (Path): Location to save the visualizations to. 
- """ - if not self.matches: - return - from .ops import xyxy2xywh - from .plotting import plot_images - - # Create batch of 4 (GT, TP, FP, FN) - labels = defaultdict(list) - for i, mtype in enumerate(["GT", "FP", "TP", "FN"]): - mbatch = self.matches[mtype] - if "conf" not in mbatch: - mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device) - mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i - for k in mbatch.keys(): - labels[k] += mbatch[k] - - labels = {k: torch.stack(v, 0) if len(v) else v for k, v in labels.items()} - if not self.task == "obb" and len(labels["bboxes"]): - labels["bboxes"] = xyxy2xywh(labels["bboxes"]) - (save_dir / "visualizations").mkdir(parents=True, exist_ok=True) - plot_images( - labels, - img.repeat(4, 1, 1, 1), - paths=["Ground Truth", "False Positives", "True Positives", "False Negatives"], - fname=save_dir / "visualizations" / Path(im_file).name, - names=self.names, - max_subplots=4, - conf_thres=0.001, - ) - - @TryExcept(msg="ConfusionMatrix plot failure") - @plt_settings() - def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None): - """ - Plot the confusion matrix using matplotlib and save it to a file. - - Args: - normalize (bool, optional): Whether to normalize the confusion matrix. - save_dir (str, optional): Directory where the plot will be saved. - on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered. 
- """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - - array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns - array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) - - fig, ax = plt.subplots(1, 1, figsize=(12, 9)) - names, n = list(self.names.values()), self.nc - if self.nc >= 100: # downsample for large class count - k = max(2, self.nc // 60) # step size for downsampling, always > 1 - keep_idx = slice(None, None, k) # create slice instead of array - names = names[keep_idx] # slice class names - array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols - n = (self.nc + k - 1) // k # number of retained classes - nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed - ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto" - xy_ticks = np.arange(len(ticklabels)) - tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6 - label_fontsize = max(6, 12 - 0.1 * nc) - title_fontsize = max(6, 12 - 0.1 * nc) - btm = max(0.1, 0.25 - 0.001 * nc) # Minimum value is 0.1 - with warnings.catch_warnings(): - warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered - im = ax.imshow(array, cmap="Blues", vmin=0.0, interpolation="none") - ax.xaxis.set_label_position("bottom") - if nc < 30: # Add score for each cell of confusion matrix - color_threshold = 0.45 * (1 if normalize else np.nanmax(array)) # text color threshold - for i, row in enumerate(array[:nc]): - for j, val in enumerate(row[:nc]): - val = array[i, j] - if np.isnan(val): - continue - ax.text( - j, - i, - f"{val:.2f}" if normalize else f"{int(val)}", - ha="center", - va="center", - fontsize=10, - color="white" if val > color_threshold else "black", - ) - cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05) - title = "Confusion Matrix" + " Normalized" * normalize - ax.set_xlabel("True", fontsize=label_fontsize, 
labelpad=10) - ax.set_ylabel("Predicted", fontsize=label_fontsize, labelpad=10) - ax.set_title(title, fontsize=title_fontsize, pad=20) - ax.set_xticks(xy_ticks) - ax.set_yticks(xy_ticks) - ax.tick_params(axis="x", bottom=True, top=False, labelbottom=True, labeltop=False) - ax.tick_params(axis="y", left=True, right=False, labelleft=True, labelright=False) - if ticklabels != "auto": - ax.set_xticklabels(ticklabels, fontsize=tick_fontsize, rotation=90, ha="center") - ax.set_yticklabels(ticklabels, fontsize=tick_fontsize) - for s in {"left", "right", "bottom", "top", "outline"}: - if s != "outline": - ax.spines[s].set_visible(False) # Confusion matrix plot don't have outline - cbar.ax.spines[s].set_visible(False) - fig.subplots_adjust(left=0, right=0.84, top=0.94, bottom=btm) # Adjust layout to ensure equal margins - plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png" - fig.savefig(plot_fname, dpi=250) - plt.close(fig) - if on_plot: - on_plot(plot_fname) - - def print(self): - """Print the confusion matrix to the console.""" - for i in range(self.matrix.shape[0]): - LOGGER.info(" ".join(map(str, self.matrix[i]))) - - def summary(self, normalize: bool = False, decimals: int = 5) -> List[Dict[str, float]]: - """ - Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional - normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON, or SQL. - - Args: - normalize (bool): Whether to normalize the confusion matrix values. - decimals (int): Number of decimal places to round the output values to. - - Returns: - (List[Dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding values for all actual classes. 
- - Examples: - >>> results = model.val(data="coco8.yaml", plots=True) - >>> cm_dict = results.confusion_matrix.summary(normalize=True, decimals=5) - >>> print(cm_dict) - """ - import re - - names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"] - clean_names, seen = [], set() - for name in names: - clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name) - original_clean = clean_name - counter = 1 - while clean_name.lower() in seen: - clean_name = f"{original_clean}_{counter}" - counter += 1 - seen.add(clean_name.lower()) - clean_names.append(clean_name) - array = (self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)).round(decimals) - return [ - dict({"Predicted": clean_names[i]}, **{clean_names[j]: array[i, j] for j in range(len(clean_names))}) - for i in range(len(clean_names)) - ] - - -def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray: - """Box filter of fraction f.""" - nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) - p = np.ones(nf // 2) # ones padding - yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded - return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed - - -@plt_settings() -def plot_pr_curve( - px: np.ndarray, - py: np.ndarray, - ap: np.ndarray, - save_dir: Path = Path("pr_curve.png"), - names: Dict[int, str] = {}, - on_plot=None, -): - """ - Plot precision-recall curve. - - Args: - px (np.ndarray): X values for the PR curve. - py (np.ndarray): Y values for the PR curve. - ap (np.ndarray): Average precision values. - save_dir (Path, optional): Path to save the plot. - names (Dict[int, str], optional): Dictionary mapping class indices to class names. - on_plot (callable, optional): Function to call after plot is saved. 
- """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - - fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) - py = np.stack(py, axis=1) - - if 0 < len(names) < 21: # display per-class legend if < 21 classes - for i, y in enumerate(py.T): - ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision) - else: - ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision) - - ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5") - ax.set_xlabel("Recall") - ax.set_ylabel("Precision") - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") - ax.set_title("Precision-Recall Curve") - fig.savefig(save_dir, dpi=250) - plt.close(fig) - if on_plot: - on_plot(save_dir) - - -@plt_settings() -def plot_mc_curve( - px: np.ndarray, - py: np.ndarray, - save_dir: Path = Path("mc_curve.png"), - names: Dict[int, str] = {}, - xlabel: str = "Confidence", - ylabel: str = "Metric", - on_plot=None, -): - """ - Plot metric-confidence curve. - - Args: - px (np.ndarray): X values for the metric-confidence curve. - py (np.ndarray): Y values for the metric-confidence curve. - save_dir (Path, optional): Path to save the plot. - names (Dict[int, str], optional): Dictionary mapping class indices to class names. - xlabel (str, optional): X-axis label. - ylabel (str, optional): Y-axis label. - on_plot (callable, optional): Function to call after plot is saved. 
- """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - - fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) - - if 0 < len(names) < 21: # display per-class legend if < 21 classes - for i, y in enumerate(py): - ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric) - else: - ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric) - - y = smooth(py.mean(0), 0.1) - ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}") - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") - ax.set_title(f"{ylabel}-Confidence Curve") - fig.savefig(save_dir, dpi=250) - plt.close(fig) - if on_plot: - on_plot(save_dir) - - -def compute_ap(recall: List[float], precision: List[float]) -> Tuple[float, np.ndarray, np.ndarray]: - """ - Compute the average precision (AP) given the recall and precision curves. - - Args: - recall (list): The recall curve. - precision (list): The precision curve. - - Returns: - ap (float): Average precision. - mpre (np.ndarray): Precision envelope curve. - mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end. 
- """ - # Append sentinel values to beginning and end - mrec = np.concatenate(([0.0], recall, [1.0])) - mpre = np.concatenate(([1.0], precision, [0.0])) - - # Compute the precision envelope - mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) - - # Integrate area under curve - method = "interp" # methods: 'continuous', 'interp' - if method == "interp": - x = np.linspace(0, 1, 101) # 101-point interp (COCO) - func = np.trapezoid if checks.check_version(np.__version__, ">=2.0") else np.trapz # np.trapz deprecated - ap = func(np.interp(x, mrec, mpre), x) # integrate - else: # 'continuous' - i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes - ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve - - return ap, mpre, mrec - - -def ap_per_class( - tp: np.ndarray, - conf: np.ndarray, - pred_cls: np.ndarray, - target_cls: np.ndarray, - plot: bool = False, - on_plot=None, - save_dir: Path = Path(), - names: Dict[int, str] = {}, - eps: float = 1e-16, - prefix: str = "", -) -> Tuple: - """ - Compute the average precision per class for object detection evaluation. - - Args: - tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False). - conf (np.ndarray): Array of confidence scores of the detections. - pred_cls (np.ndarray): Array of predicted classes of the detections. - target_cls (np.ndarray): Array of true classes of the detections. - plot (bool, optional): Whether to plot PR curves or not. - on_plot (callable, optional): A callback to pass plots path and data when they are rendered. - save_dir (Path, optional): Directory to save the PR curves. - names (Dict[int, str], optional): Dictionary of class names to plot PR curves. - eps (float, optional): A small value to avoid division by zero. - prefix (str, optional): A prefix string for saving the plot files. - - Returns: - tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class. 
- fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class. - p (np.ndarray): Precision values at threshold given by max F1 metric for each class. - r (np.ndarray): Recall values at threshold given by max F1 metric for each class. - f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class. - ap (np.ndarray): Average precision for each class at different IoU thresholds. - unique_classes (np.ndarray): An array of unique classes that have data. - p_curve (np.ndarray): Precision curves for each class. - r_curve (np.ndarray): Recall curves for each class. - f1_curve (np.ndarray): F1-score curves for each class. - x (np.ndarray): X-axis values for the curves. - prec_values (np.ndarray): Precision values at mAP@0.5 for each class. - """ - # Sort by objectness - i = np.argsort(-conf) - tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] - - # Find unique classes - unique_classes, nt = np.unique(target_cls, return_counts=True) - nc = unique_classes.shape[0] # number of classes, number of detections - - # Create Precision-Recall curve and compute AP for each class - x, prec_values = np.linspace(0, 1, 1000), [] - - # Average precision, precision and recall curves - ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) - for ci, c in enumerate(unique_classes): - i = pred_cls == c - n_l = nt[ci] # number of labels - n_p = i.sum() # number of predictions - if n_p == 0 or n_l == 0: - continue - - # Accumulate FPs and TPs - fpc = (1 - tp[i]).cumsum(0) - tpc = tp[i].cumsum(0) - - # Recall - recall = tpc / (n_l + eps) # recall curve - r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases - - # Precision - precision = tpc / (tpc + fpc) # precision curve - p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score - - # AP from recall-precision curve - for j in range(tp.shape[1]): - ap[ci, j], mpre, mrec = 
compute_ap(recall[:, j], precision[:, j]) - if j == 0: - prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5 - - prec_values = np.array(prec_values) if prec_values else np.zeros((1, 1000)) # (nc, 1000) - - # Compute F1 (harmonic mean of precision and recall) - f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps) - names = {i: names[k] for i, k in enumerate(unique_classes) if k in names} # dict: only classes that have data - if plot: - plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot) - plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot) - plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot) - plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot) - - i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index - p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values - tp = (r * nt).round() # true positives - fp = (tp / (p + eps) - tp).round() # false positives - return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values - - -class Metric(SimpleClass): - """ - Class for computing evaluation metrics for Ultralytics YOLO models. - - Attributes: - p (list): Precision for each class. Shape: (nc,). - r (list): Recall for each class. Shape: (nc,). - f1 (list): F1 score for each class. Shape: (nc,). - all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10). - ap_class_index (list): Index of class for each AP score. Shape: (nc,). - nc (int): Number of classes. - - Methods: - ap50: AP at IoU threshold of 0.5 for all classes. - ap: AP at IoU thresholds from 0.5 to 0.95 for all classes. - mp: Mean precision of all classes. - mr: Mean recall of all classes. - map50: Mean AP at IoU threshold of 0.5 for all classes. 
- map75: Mean AP at IoU threshold of 0.75 for all classes. - map: Mean AP at IoU thresholds from 0.5 to 0.95 for all classes. - mean_results: Mean of results, returns mp, mr, map50, map. - class_result: Class-aware result, returns p[i], r[i], ap50[i], ap[i]. - maps: mAP of each class. - fitness: Model fitness as a weighted combination of metrics. - update: Update metric attributes with new evaluation results. - curves: Provides a list of curves for accessing specific metrics like precision, recall, F1, etc. - curves_results: Provide a list of results for accessing specific metrics like precision, recall, F1, etc. - """ - - def __init__(self) -> None: - """Initialize a Metric instance for computing evaluation metrics for the YOLOv8 model.""" - self.p = [] # (nc, ) - self.r = [] # (nc, ) - self.f1 = [] # (nc, ) - self.all_ap = [] # (nc, 10) - self.ap_class_index = [] # (nc, ) - self.nc = 0 - - @property - def ap50(self) -> Union[np.ndarray, List]: - """ - Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes. - - Returns: - (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available. - """ - return self.all_ap[:, 0] if len(self.all_ap) else [] - - @property - def ap(self) -> Union[np.ndarray, List]: - """ - Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes. - - Returns: - (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available. - """ - return self.all_ap.mean(1) if len(self.all_ap) else [] - - @property - def mp(self) -> float: - """ - Return the Mean Precision of all classes. - - Returns: - (float): The mean precision of all classes. - """ - return self.p.mean() if len(self.p) else 0.0 - - @property - def mr(self) -> float: - """ - Return the Mean Recall of all classes. - - Returns: - (float): The mean recall of all classes. 
- """ - return self.r.mean() if len(self.r) else 0.0 - - @property - def map50(self) -> float: - """ - Return the mean Average Precision (mAP) at an IoU threshold of 0.5. - - Returns: - (float): The mAP at an IoU threshold of 0.5. - """ - return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0 - - @property - def map75(self) -> float: - """ - Return the mean Average Precision (mAP) at an IoU threshold of 0.75. - - Returns: - (float): The mAP at an IoU threshold of 0.75. - """ - return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0 - - @property - def map(self) -> float: - """ - Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05. - - Returns: - (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05. - """ - return self.all_ap.mean() if len(self.all_ap) else 0.0 - - def mean_results(self) -> List[float]: - """Return mean of results, mp, mr, map50, map.""" - return [self.mp, self.mr, self.map50, self.map] - - def class_result(self, i: int) -> Tuple[float, float, float, float]: - """Return class-aware result, p[i], r[i], ap50[i], ap[i].""" - return self.p[i], self.r[i], self.ap50[i], self.ap[i] - - @property - def maps(self) -> np.ndarray: - """Return mAP of each class.""" - maps = np.zeros(self.nc) + self.map - for i, c in enumerate(self.ap_class_index): - maps[c] = self.ap[i] - return maps - - def fitness(self) -> float: - """Return model fitness as a weighted combination of metrics.""" - w = [0.0, 0.0, 0.1, 0.9] # weights for [P, R, mAP@0.5, mAP@0.5:0.95] - return (np.nan_to_num(np.array(self.mean_results())) * w).sum() - - def update(self, results: tuple): - """ - Update the evaluation metrics with a new set of results. - - Args: - results (tuple): A tuple containing evaluation metrics: - - p (list): Precision for each class. - - r (list): Recall for each class. - - f1 (list): F1 score for each class. - - all_ap (list): AP scores for all classes and all IoU thresholds. 
- - ap_class_index (list): Index of class for each AP score. - - p_curve (list): Precision curve for each class. - - r_curve (list): Recall curve for each class. - - f1_curve (list): F1 curve for each class. - - px (list): X values for the curves. - - prec_values (list): Precision values for each class. - """ - ( - self.p, - self.r, - self.f1, - self.all_ap, - self.ap_class_index, - self.p_curve, - self.r_curve, - self.f1_curve, - self.px, - self.prec_values, - ) = results - - @property - def curves(self) -> List: - """Return a list of curves for accessing specific metrics curves.""" - return [] - - @property - def curves_results(self) -> List[List]: - """Return a list of curves for accessing specific metrics curves.""" - return [ - [self.px, self.prec_values, "Recall", "Precision"], - [self.px, self.f1_curve, "Confidence", "F1"], - [self.px, self.p_curve, "Confidence", "Precision"], - [self.px, self.r_curve, "Confidence", "Recall"], - ] - - -class DetMetrics(SimpleClass, DataExportMixin): - """ - Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP). - - Attributes: - names (Dict[int, str]): A dictionary of class names. - box (Metric): An instance of the Metric class for storing detection results. - speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process. - task (str): The task type, set to 'detect'. - stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images. - nt_per_class: Number of targets per class. - nt_per_image: Number of targets per image. - - Methods: - update_stats: Update statistics by appending new values to existing stat collections. - process: Process predicted results for object detection and update metrics. - clear_stats: Clear the stored statistics. - keys: Return a list of keys for accessing specific metrics. 
- mean_results: Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95. - class_result: Return the result of evaluating the performance of an object detection model on a specific class. - maps: Return mean Average Precision (mAP) scores per class. - fitness: Return the fitness of box object. - ap_class_index: Return the average precision index per class. - results_dict: Return dictionary of computed performance metrics and statistics. - curves: Return a list of curves for accessing specific metrics curves. - curves_results: Return a list of computed performance metrics and statistics. - summary: Generate a summarized representation of per-class detection metrics as a list of dictionaries. - """ - - def __init__(self, names: Dict[int, str] = {}) -> None: - """ - Initialize a DetMetrics instance with a save directory, plot flag, and class names. - - Args: - names (Dict[int, str], optional): Dictionary of class names. - """ - self.names = names - self.box = Metric() - self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} - self.task = "detect" - self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[]) - self.nt_per_class = None - self.nt_per_image = None - - def update_stats(self, stat: Dict[str, Any]) -> None: - """ - Update statistics by appending new values to existing stat collections. - - Args: - stat (Dict[str, any]): Dictionary containing new statistical values to append. - Keys should match existing keys in self.stats. - """ - for k in self.stats.keys(): - self.stats[k].append(stat[k]) - - def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]: - """ - Process predicted results for object detection and update metrics. - - Args: - save_dir (Path): Directory to save plots. Defaults to Path("."). - plot (bool): Whether to plot precision-recall curves. Defaults to False. 
- on_plot (callable, optional): Function to call after plots are generated. Defaults to None. - - Returns: - (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays. - """ - stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()} # to numpy - if len(stats) == 0: - return stats - results = ap_per_class( - stats["tp"], - stats["conf"], - stats["pred_cls"], - stats["target_cls"], - plot=plot, - save_dir=save_dir, - names=self.names, - on_plot=on_plot, - prefix="Box", - )[2:] - self.box.nc = len(self.names) - self.box.update(results) - self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names)) - self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names)) - return stats - - def clear_stats(self): - """Clear the stored statistics.""" - for v in self.stats.values(): - v.clear() - - @property - def keys(self) -> List[str]: - """Return a list of keys for accessing specific metrics.""" - return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"] - - def mean_results(self) -> List[float]: - """Calculate mean of detected objects & return precision, recall, mAP50, and mAP50-95.""" - return self.box.mean_results() - - def class_result(self, i: int) -> Tuple[float, float, float, float]: - """Return the result of evaluating the performance of an object detection model on a specific class.""" - return self.box.class_result(i) - - @property - def maps(self) -> np.ndarray: - """Return mean Average Precision (mAP) scores per class.""" - return self.box.maps - - @property - def fitness(self) -> float: - """Return the fitness of box object.""" - return self.box.fitness() - - @property - def ap_class_index(self) -> List: - """Return the average precision index per class.""" - return self.box.ap_class_index - - @property - def results_dict(self) -> Dict[str, float]: - """Return dictionary of computed performance metrics and statistics.""" - return 
dict(zip(self.keys + ["fitness"], self.mean_results() + [self.fitness])) - - @property - def curves(self) -> List[str]: - """Return a list of curves for accessing specific metrics curves.""" - return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"] - - @property - def curves_results(self) -> List[List]: - """Return a list of computed performance metrics and statistics.""" - return self.box.curves_results - - def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]: - """ - Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared - scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class. - - Args: - normalize (bool): For Detect metrics, everything is normalized by default [0-1]. - decimals (int): Number of decimal places to round the metrics values to. - - Returns: - (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values. - - Examples: - >>> results = model.val(data="coco8.yaml") - >>> detection_summary = results.summary() - >>> print(detection_summary) - """ - per_class = { - "Box-P": self.box.p, - "Box-R": self.box.r, - "Box-F1": self.box.f1, - } - return [ - { - "Class": self.names[self.ap_class_index[i]], - "Images": self.nt_per_image[self.ap_class_index[i]], - "Instances": self.nt_per_class[self.ap_class_index[i]], - **{k: round(v[i], decimals) for k, v in per_class.items()}, - "mAP50": round(self.class_result(i)[2], decimals), - "mAP50-95": round(self.class_result(i)[3], decimals), - } - for i in range(len(per_class["Box-P"])) - ] - - -class SegmentMetrics(DetMetrics): - """ - Calculate and aggregate detection and segmentation metrics over a given set of classes. - - Attributes: - names (Dict[int, str]): Dictionary of class names. - box (Metric): An instance of the Metric class for storing detection results. 
- seg (Metric): An instance of the Metric class to calculate mask segmentation metrics. - speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process. - task (str): The task type, set to 'segment'. - stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images. - nt_per_class: Number of targets per class. - nt_per_image: Number of targets per image. - - Methods: - process: Process the detection and segmentation metrics over the given set of predictions. - keys: Return a list of keys for accessing metrics. - mean_results: Return the mean metrics for bounding box and segmentation results. - class_result: Return classification results for a specified class index. - maps: Return mAP scores for object detection and semantic segmentation models. - fitness: Return the fitness score for both segmentation and bounding box models. - curves: Return a list of curves for accessing specific metrics curves. - curves_results: Provide a list of computed performance metrics and statistics. - summary: Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. - """ - - def __init__(self, names: Dict[int, str] = {}) -> None: - """ - Initialize a SegmentMetrics instance with a save directory, plot flag, and class names. - - Args: - names (Dict[int, str], optional): Dictionary of class names. - """ - DetMetrics.__init__(self, names) - self.seg = Metric() - self.task = "segment" - self.stats["tp_m"] = [] # add additional stats for masks - - def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]: - """ - Process the detection and segmentation metrics over the given set of predictions. - - Args: - save_dir (Path): Directory to save plots. Defaults to Path("."). - plot (bool): Whether to plot precision-recall curves. Defaults to False. 
- on_plot (callable, optional): Function to call after plots are generated. Defaults to None. - - Returns: - (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays. - """ - stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats - results_mask = ap_per_class( - stats["tp_m"], - stats["conf"], - stats["pred_cls"], - stats["target_cls"], - plot=plot, - on_plot=on_plot, - save_dir=save_dir, - names=self.names, - prefix="Mask", - )[2:] - self.seg.nc = len(self.names) - self.seg.update(results_mask) - return stats - - @property - def keys(self) -> List[str]: - """Return a list of keys for accessing metrics.""" - return DetMetrics.keys.fget(self) + [ - "metrics/precision(M)", - "metrics/recall(M)", - "metrics/mAP50(M)", - "metrics/mAP50-95(M)", - ] - - def mean_results(self) -> List[float]: - """Return the mean metrics for bounding box and segmentation results.""" - return DetMetrics.mean_results(self) + self.seg.mean_results() - - def class_result(self, i: int) -> List[float]: - """Return classification results for a specified class index.""" - return DetMetrics.class_result(self, i) + self.seg.class_result(i) - - @property - def maps(self) -> np.ndarray: - """Return mAP scores for object detection and semantic segmentation models.""" - return DetMetrics.maps.fget(self) + self.seg.maps - - @property - def fitness(self) -> float: - """Return the fitness score for both segmentation and bounding box models.""" - return self.seg.fitness() + DetMetrics.fitness.fget(self) - - @property - def curves(self) -> List[str]: - """Return a list of curves for accessing specific metrics curves.""" - return DetMetrics.curves.fget(self) + [ - "Precision-Recall(M)", - "F1-Confidence(M)", - "Precision-Confidence(M)", - "Recall-Confidence(M)", - ] - - @property - def curves_results(self) -> List[List]: - """Return a list of computed performance metrics and statistics.""" - return DetMetrics.curves_results.fget(self) + 
self.seg.curves_results - - def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]: - """ - Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes both - box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class. - - Args: - normalize (bool): For Segment metrics, everything is normalized by default [0-1]. - decimals (int): Number of decimal places to round the metrics values to. - - Returns: - (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values. - - Examples: - >>> results = model.val(data="coco8-seg.yaml") - >>> seg_summary = results.summary(decimals=4) - >>> print(seg_summary) - """ - per_class = { - "Mask-P": self.seg.p, - "Mask-R": self.seg.r, - "Mask-F1": self.seg.f1, - } - summary = DetMetrics.summary(self, normalize, decimals) # get box summary - for i, s in enumerate(summary): - s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}}) - return summary - - -class PoseMetrics(DetMetrics): - """ - Calculate and aggregate detection and pose metrics over a given set of classes. - - Attributes: - names (Dict[int, str]): Dictionary of class names. - pose (Metric): An instance of the Metric class to calculate pose metrics. - box (Metric): An instance of the Metric class for storing detection results. - speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process. - task (str): The task type, set to 'pose'. - stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images. - nt_per_class: Number of targets per class. - nt_per_image: Number of targets per image. - - Methods: - process: Process the detection and pose metrics over the given set of predictions. R - keys: Return a list of keys for accessing metrics. 
- mean_results: Return the mean results of box and pose. - class_result: Return the class-wise detection results for a specific class i. - maps: Return the mean average precision (mAP) per class for both box and pose detections. - fitness: Return combined fitness score for pose and box detection. - curves: Return a list of curves for accessing specific metrics curves. - curves_results: Provide a list of computed performance metrics and statistics. - summary: Generate a summarized representation of per-class pose metrics as a list of dictionaries. - """ - - def __init__(self, names: Dict[int, str] = {}) -> None: - """ - Initialize the PoseMetrics class with directory path, class names, and plotting options. - - Args: - names (Dict[int, str], optional): Dictionary of class names. - """ - super().__init__(names) - self.pose = Metric() - self.task = "pose" - self.stats["tp_p"] = [] # add additional stats for pose - - def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> Dict[str, np.ndarray]: - """ - Process the detection and pose metrics over the given set of predictions. - - Args: - save_dir (Path): Directory to save plots. Defaults to Path("."). - plot (bool): Whether to plot precision-recall curves. Defaults to False. - on_plot (callable, optional): Function to call after plots are generated. - - Returns: - (Dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays. 
- """ - stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats - results_pose = ap_per_class( - stats["tp_p"], - stats["conf"], - stats["pred_cls"], - stats["target_cls"], - plot=plot, - on_plot=on_plot, - save_dir=save_dir, - names=self.names, - prefix="Pose", - )[2:] - self.pose.nc = len(self.names) - self.pose.update(results_pose) - return stats - - @property - def keys(self) -> List[str]: - """Return a list of evaluation metric keys.""" - return DetMetrics.keys.fget(self) + [ - "metrics/precision(P)", - "metrics/recall(P)", - "metrics/mAP50(P)", - "metrics/mAP50-95(P)", - ] - - def mean_results(self) -> List[float]: - """Return the mean results of box and pose.""" - return DetMetrics.mean_results(self) + self.pose.mean_results() - - def class_result(self, i: int) -> List[float]: - """Return the class-wise detection results for a specific class i.""" - return DetMetrics.class_result(self, i) + self.pose.class_result(i) - - @property - def maps(self) -> np.ndarray: - """Return the mean average precision (mAP) per class for both box and pose detections.""" - return DetMetrics.maps.fget(self) + self.pose.maps - - @property - def fitness(self) -> float: - """Return combined fitness score for pose and box detection.""" - return self.pose.fitness() + DetMetrics.fitness.fget(self) - - @property - def curves(self) -> List[str]: - """Return a list of curves for accessing specific metrics curves.""" - return DetMetrics.curves.fget(self) + [ - "Precision-Recall(B)", - "F1-Confidence(B)", - "Precision-Confidence(B)", - "Recall-Confidence(B)", - "Precision-Recall(P)", - "F1-Confidence(P)", - "Precision-Confidence(P)", - "Recall-Confidence(P)", - ] - - @property - def curves_results(self) -> List[List]: - """Return a list of computed performance metrics and statistics.""" - return DetMetrics.curves_results.fget(self) + self.pose.curves_results - - def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, Any]]: - """ - 
Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box and - pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class. - - Args: - normalize (bool): For Pose metrics, everything is normalized by default [0-1]. - decimals (int): Number of decimal places to round the metrics values to. - - Returns: - (List[Dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values. - - Examples: - >>> results = model.val(data="coco8-pose.yaml") - >>> pose_summary = results.summary(decimals=4) - >>> print(pose_summary) - """ - per_class = { - "Pose-P": self.pose.p, - "Pose-R": self.pose.r, - "Pose-F1": self.pose.f1, - } - summary = DetMetrics.summary(self, normalize, decimals) # get box summary - for i, s in enumerate(summary): - s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}}) - return summary - - -class ClassifyMetrics(SimpleClass, DataExportMixin): - """ - Class for computing classification metrics including top-1 and top-5 accuracy. - - Attributes: - top1 (float): The top-1 accuracy. - top5 (float): The top-5 accuracy. - speed (dict): A dictionary containing the time taken for each step in the pipeline. - task (str): The task type, set to 'classify'. - - Methods: - process: Process target classes and predicted classes to compute metrics. - fitness: Return mean of top-1 and top-5 accuracies as fitness score. - results_dict: Return a dictionary with model's performance metrics and fitness score. - keys: Return a list of keys for the results_dict property. - curves: Return a list of curves for accessing specific metrics curves. - curves_results: Provide a list of computed performance metrics and statistics. - summary: Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy). 
- """ - - def __init__(self) -> None: - """Initialize a ClassifyMetrics instance.""" - self.top1 = 0 - self.top5 = 0 - self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0} - self.task = "classify" - - def process(self, targets: torch.Tensor, pred: torch.Tensor): - """ - Process target classes and predicted classes to compute metrics. - - Args: - targets (torch.Tensor): Target classes. - pred (torch.Tensor): Predicted classes. - """ - pred, targets = torch.cat(pred), torch.cat(targets) - correct = (targets[:, None] == pred).float() - acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy - self.top1, self.top5 = acc.mean(0).tolist() - - @property - def fitness(self) -> float: - """Return mean of top-1 and top-5 accuracies as fitness score.""" - return (self.top1 + self.top5) / 2 - - @property - def results_dict(self) -> Dict[str, float]: - """Return a dictionary with model's performance metrics and fitness score.""" - return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness])) - - @property - def keys(self) -> List[str]: - """Return a list of keys for the results_dict property.""" - return ["metrics/accuracy_top1", "metrics/accuracy_top5"] - - @property - def curves(self) -> List: - """Return a list of curves for accessing specific metrics curves.""" - return [] - - @property - def curves_results(self) -> List: - """Return a list of curves for accessing specific metrics curves.""" - return [] - - def summary(self, normalize: bool = True, decimals: int = 5) -> List[Dict[str, float]]: - """ - Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy). - - Args: - normalize (bool): For Classify metrics, everything is normalized by default [0-1]. - decimals (int): Number of decimal places to round the metrics values to. - - Returns: - (List[Dict[str, float]]): A list with one dictionary containing Top-1 and Top-5 classification accuracy. 
- - Examples: - >>> results = model.val(data="imagenet10") - >>> classify_summary = results.summary(decimals=4) - >>> print(classify_summary) - """ - return [{"top1_acc": round(self.top1, decimals), "top5_acc": round(self.top5, decimals)}] - - -class OBBMetrics(DetMetrics): - """ - Metrics for evaluating oriented bounding box (OBB) detection. - - Attributes: - names (Dict[int, str]): Dictionary of class names. - box (Metric): An instance of the Metric class for storing detection results. - speed (Dict[str, float]): A dictionary for storing execution times of different parts of the detection process. - task (str): The task type, set to 'obb'. - stats (Dict[str, List]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images. - nt_per_class: Number of targets per class. - nt_per_image: Number of targets per image. - - References: - https://arxiv.org/pdf/2106.06072.pdf - """ - - def __init__(self, names: Dict[int, str] = {}) -> None: - """ - Initialize an OBBMetrics instance with directory, plotting, and class names. - - Args: - names (Dict[int, str], optional): Dictionary of class names. - """ - DetMetrics.__init__(self, names) - # TODO: probably remove task as well - self.task = "obb" diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/ops.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/ops.py deleted file mode 100644 index 40d191d..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/ops.py +++ /dev/null @@ -1,888 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import contextlib -import math -import re -import time -from typing import Optional - -import cv2 -import numpy as np -import torch -import torch.nn.functional as F - -from ultralytics.utils import LOGGER -from ultralytics.utils.metrics import batch_probiou - - -class Profile(contextlib.ContextDecorator): - """ - Ultralytics Profile class for timing code execution. 
- - Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing - measurements with CUDA synchronization support for GPU operations. - - Attributes: - t (float): Accumulated time in seconds. - device (torch.device): Device used for model inference. - cuda (bool): Whether CUDA is being used for timing synchronization. - - Examples: - Use as a context manager to time code execution - >>> with Profile(device=device) as dt: - ... pass # slow operation here - >>> print(dt) # prints "Elapsed time is 9.5367431640625e-07 s" - - Use as a decorator to time function execution - >>> @Profile() - ... def slow_function(): - ... time.sleep(0.1) - """ - - def __init__(self, t: float = 0.0, device: Optional[torch.device] = None): - """ - Initialize the Profile class. - - Args: - t (float): Initial accumulated time in seconds. - device (torch.device, optional): Device used for model inference to enable CUDA synchronization. - """ - self.t = t - self.device = device - self.cuda = bool(device and str(device).startswith("cuda")) - - def __enter__(self): - """Start timing.""" - self.start = self.time() - return self - - def __exit__(self, type, value, traceback): # noqa - """Stop timing.""" - self.dt = self.time() - self.start # delta-time - self.t += self.dt # accumulate dt - - def __str__(self): - """Return a human-readable string representing the accumulated elapsed time.""" - return f"Elapsed time is {self.t} s" - - def time(self): - """Get current time with CUDA synchronization if applicable.""" - if self.cuda: - torch.cuda.synchronize(self.device) - return time.perf_counter() - - -def segment2box(segment, width: int = 640, height: int = 640): - """ - Convert segment coordinates to bounding box coordinates. - - Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates. - Applies inside-image constraint and clips coordinates when necessary. 
- - Args: - segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points. - width (int): Width of the image in pixels. - height (int): Height of the image in pixels. - - Returns: - (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2]. - """ - x, y = segment.T # segment xy - # Clip coordinates if 3 out of 4 sides are outside the image - if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3: - x = x.clip(0, width) - y = y.clip(0, height) - inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height) - x = x[inside] - y = y[inside] - return ( - np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype) - if any(x) - else np.zeros(4, dtype=segment.dtype) - ) # xyxy - - -def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False): - """ - Rescale bounding boxes from one image shape to another. - - Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes. - Supports both xyxy and xywh box formats. - - Args: - img1_shape (tuple): Shape of the source image (height, width). - boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4). - img0_shape (tuple): Shape of the target image (height, width). - ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes. - padding (bool): Whether boxes are based on YOLO-style augmented images with padding. - xywh (bool): Whether box format is xywh (True) or xyxy (False). - - Returns: - (torch.Tensor): Rescaled bounding boxes in the same format as input. 
- """ - if ratio_pad is None: # calculate from img0_shape - gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new - pad = ( - round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1), - round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1), - ) # wh padding - else: - gain = ratio_pad[0][0] - pad = ratio_pad[1] - - if padding: - boxes[..., 0] -= pad[0] # x padding - boxes[..., 1] -= pad[1] # y padding - if not xywh: - boxes[..., 2] -= pad[0] # x padding - boxes[..., 3] -= pad[1] # y padding - boxes[..., :4] /= gain - return clip_boxes(boxes, img0_shape) - - -def make_divisible(x: int, divisor): - """ - Return the nearest number that is divisible by the given divisor. - - Args: - x (int): The number to make divisible. - divisor (int | torch.Tensor): The divisor. - - Returns: - (int): The nearest number divisible by the divisor. - """ - if isinstance(divisor, torch.Tensor): - divisor = int(divisor.max()) # to int - return math.ceil(x / divisor) * divisor - - -def nms_rotated(boxes, scores, threshold: float = 0.45, use_triu: bool = True): - """ - Perform NMS on oriented bounding boxes using probiou and fast-nms. - - Args: - boxes (torch.Tensor): Rotated bounding boxes with shape (N, 5) in xywhr format. - scores (torch.Tensor): Confidence scores with shape (N,). - threshold (float): IoU threshold for NMS. - use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations. - - Returns: - (torch.Tensor): Indices of boxes to keep after NMS. 
- """ - sorted_idx = torch.argsort(scores, descending=True) - boxes = boxes[sorted_idx] - ious = batch_probiou(boxes, boxes) - if use_triu: - ious = ious.triu_(diagonal=1) - # NOTE: handle the case when len(boxes) hence exportable by eliminating if-else condition - pick = torch.nonzero((ious >= threshold).sum(0) <= 0).squeeze_(-1) - else: - n = boxes.shape[0] - row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n) - col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1) - upper_mask = row_idx < col_idx - ious = ious * upper_mask - # Zeroing these scores ensures the additional indices would not affect the final results - scores[~((ious >= threshold).sum(0) <= 0)] = 0 - # NOTE: return indices with fixed length to avoid TFLite reshape error - pick = torch.topk(scores, scores.shape[0]).indices - return sorted_idx[pick] - - -def non_max_suppression( - prediction, - conf_thres: float = 0.25, - iou_thres: float = 0.45, - classes=None, - agnostic: bool = False, - multi_label: bool = False, - labels=(), - max_det: int = 300, - nc: int = 0, # number of classes (optional) - max_time_img: float = 0.05, - max_nms: int = 30000, - max_wh: int = 7680, - in_place: bool = True, - rotated: bool = False, - end2end: bool = False, - return_idxs: bool = False, -): - """ - Perform non-maximum suppression (NMS) on prediction results. - - Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple - detection formats including standard boxes, rotated boxes, and masks. - - Args: - prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes) - containing boxes, classes, and optional masks. - conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0. - iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0. - classes (List[int], optional): List of class indices to consider. 
If None, all classes are considered. - agnostic (bool): Whether to perform class-agnostic NMS. - multi_label (bool): Whether each box can have multiple labels. - labels (List[List[Union[int, float, torch.Tensor]]]): A priori labels for each image. - max_det (int): Maximum number of detections to keep per image. - nc (int): Number of classes. Indices after this are considered masks. - max_time_img (float): Maximum time in seconds for processing one image. - max_nms (int): Maximum number of boxes for torchvision.ops.nms(). - max_wh (int): Maximum box width and height in pixels. - in_place (bool): Whether to modify the input prediction tensor in place. - rotated (bool): Whether to handle Oriented Bounding Boxes (OBB). - end2end (bool): Whether the model is end-to-end and doesn't require NMS. - return_idxs (bool): Whether to return the indices of kept detections. - - Returns: - output (List[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks) - containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...). - keepi (List[torch.Tensor]): Indices of kept detections if return_idxs=True. - """ - import torchvision # scope for faster 'import ultralytics' - - # Checks - assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" - assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" - if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation model, output = (inference_out, loss_out) - prediction = prediction[0] # select only inference output - if classes is not None: - classes = torch.tensor(classes, device=prediction.device) - - if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 
1,300,6) - output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction] - if classes is not None: - output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output] - return output - - bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300) - nc = nc or (prediction.shape[1] - 4) # number of classes - extra = prediction.shape[1] - nc - 4 # number of extra info - mi = 4 + nc # mask start index - xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates - xinds = torch.stack([torch.arange(len(i), device=prediction.device) for i in xc])[..., None] # to track idxs - - # Settings - # min_wh = 2 # (pixels) minimum box width and height - time_limit = 2.0 + max_time_img * bs # seconds to quit after - multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) - - prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84) - if not rotated: - if in_place: - prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy - else: - prediction = torch.cat((xywh2xyxy(prediction[..., :4]), prediction[..., 4:]), dim=-1) # xywh to xyxy - - t = time.time() - output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs - keepi = [torch.zeros((0, 1), device=prediction.device)] * bs # to store the kept idxs - for xi, (x, xk) in enumerate(zip(prediction, xinds)): # image index, (preds, preds indices) - # Apply constraints - # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height - filt = xc[xi] # confidence - x, xk = x[filt], xk[filt] - - # Cat apriori labels if autolabelling - if labels and len(labels[xi]) and not rotated: - lb = labels[xi] - v = torch.zeros((len(lb), nc + extra + 4), device=x.device) - v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box - v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls - x = torch.cat((x, v), 0) - - # If none remain process next image - if not x.shape[0]: - continue - - # Detections matrix nx6 (xyxy, conf, cls) - box, cls, mask = x.split((4, nc, extra), 1) - - if multi_label: - 
i, j = torch.where(cls > conf_thres) - x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1) - xk = xk[i] - else: # best class only - conf, j = cls.max(1, keepdim=True) - filt = conf.view(-1) > conf_thres - x = torch.cat((box, conf, j.float(), mask), 1)[filt] - xk = xk[filt] - - # Filter by class - if classes is not None: - filt = (x[:, 5:6] == classes).any(1) - x, xk = x[filt], xk[filt] - - # Check shape - n = x.shape[0] # number of boxes - if not n: # no boxes - continue - if n > max_nms: # excess boxes - filt = x[:, 4].argsort(descending=True)[:max_nms] # sort by confidence and remove excess boxes - x, xk = x[filt], xk[filt] - - # Batched NMS - c = x[:, 5:6] * (0 if agnostic else max_wh) # classes - scores = x[:, 4] # scores - if rotated: - boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr - i = nms_rotated(boxes, scores, iou_thres) - else: - boxes = x[:, :4] + c # boxes (offset by class) - i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS - i = i[:max_det] # limit detections - - output[xi], keepi[xi] = x[i], xk[i].reshape(-1) - if (time.time() - t) > time_limit: - LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded") - break # time limit exceeded - - return (output, keepi) if return_idxs else output - - -def clip_boxes(boxes, shape): - """ - Clip bounding boxes to image boundaries. - - Args: - boxes (torch.Tensor | np.ndarray): Bounding boxes to clip. - shape (tuple): Image shape as (height, width). - - Returns: - (torch.Tensor | np.ndarray): Clipped bounding boxes. 
- """ - if isinstance(boxes, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) - boxes[..., 0] = boxes[..., 0].clamp(0, shape[1]) # x1 - boxes[..., 1] = boxes[..., 1].clamp(0, shape[0]) # y1 - boxes[..., 2] = boxes[..., 2].clamp(0, shape[1]) # x2 - boxes[..., 3] = boxes[..., 3].clamp(0, shape[0]) # y2 - else: # np.array (faster grouped) - boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, shape[1]) # x1, x2 - boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, shape[0]) # y1, y2 - return boxes - - -def clip_coords(coords, shape): - """ - Clip line coordinates to image boundaries. - - Args: - coords (torch.Tensor | np.ndarray): Line coordinates to clip. - shape (tuple): Image shape as (height, width). - - Returns: - (torch.Tensor | np.ndarray): Clipped coordinates. - """ - if isinstance(coords, torch.Tensor): # faster individually (WARNING: inplace .clamp_() Apple MPS bug) - coords[..., 0] = coords[..., 0].clamp(0, shape[1]) # x - coords[..., 1] = coords[..., 1].clamp(0, shape[0]) # y - else: # np.array (faster grouped) - coords[..., 0] = coords[..., 0].clip(0, shape[1]) # x - coords[..., 1] = coords[..., 1].clip(0, shape[0]) # y - return coords - - -def scale_image(masks, im0_shape, ratio_pad=None): - """ - Rescale masks to original image size. - - Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding - that was applied during preprocessing. - - Args: - masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3]. - im0_shape (tuple): Original image shape as (height, width). - ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)). - - Returns: - (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions. 
- """ - # Rescale coordinates (xyxy) from im1_shape to im0_shape - im1_shape = masks.shape - if im1_shape[:2] == im0_shape[:2]: - return masks - if ratio_pad is None: # calculate from im0_shape - gain = min(im1_shape[0] / im0_shape[0], im1_shape[1] / im0_shape[1]) # gain = old / new - pad = (im1_shape[1] - im0_shape[1] * gain) / 2, (im1_shape[0] - im0_shape[0] * gain) / 2 # wh padding - else: - pad = ratio_pad[1] - - top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) - bottom, right = ( - im1_shape[0] - int(round(pad[1] + 0.1)), - im1_shape[1] - int(round(pad[0] + 0.1)), - ) - - if len(masks.shape) < 2: - raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}') - masks = masks[top:bottom, left:right] - masks = cv2.resize(masks, (im0_shape[1], im0_shape[0])) - if len(masks.shape) == 2: - masks = masks[:, :, None] - - return masks - - -def xyxy2xywh(x): - """ - Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the - top-left corner and (x2, y2) is the bottom-right corner. - - Args: - x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format. - - Returns: - (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format. - """ - assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" - y = empty_like(x) # faster than clone/copy - y[..., 0] = (x[..., 0] + x[..., 2]) / 2 # x center - y[..., 1] = (x[..., 1] + x[..., 3]) / 2 # y center - y[..., 2] = x[..., 2] - x[..., 0] # width - y[..., 3] = x[..., 3] - x[..., 1] # height - return y - - -def xywh2xyxy(x): - """ - Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the - top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel. 
- - Args: - x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format. - - Returns: - (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format. - """ - assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" - y = empty_like(x) # faster than clone/copy - xy = x[..., :2] # centers - wh = x[..., 2:] / 2 # half width-height - y[..., :2] = xy - wh # top left xy - y[..., 2:] = xy + wh # bottom right xy - return y - - -def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0): - """ - Convert normalized bounding box coordinates to pixel coordinates. - - Args: - x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format. - w (int): Image width in pixels. - h (int): Image height in pixels. - padw (int): Padding width in pixels. - padh (int): Padding height in pixels. - - Returns: - y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where - x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box. - """ - assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" - y = empty_like(x) # faster than clone/copy - y[..., 0] = w * (x[..., 0] - x[..., 2] / 2) + padw # top left x - y[..., 1] = h * (x[..., 1] - x[..., 3] / 2) + padh # top left y - y[..., 2] = w * (x[..., 0] + x[..., 2] / 2) + padw # bottom right x - y[..., 3] = h * (x[..., 1] + x[..., 3] / 2) + padh # bottom right y - return y - - -def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0): - """ - Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y, - width and height are normalized to image dimensions. - - Args: - x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format. - w (int): Image width in pixels. - h (int): Image height in pixels. 
- clip (bool): Whether to clip boxes to image boundaries. - eps (float): Minimum value for box width and height. - - Returns: - (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format. - """ - if clip: - x = clip_boxes(x, (h - eps, w - eps)) - assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}" - y = empty_like(x) # faster than clone/copy - y[..., 0] = ((x[..., 0] + x[..., 2]) / 2) / w # x center - y[..., 1] = ((x[..., 1] + x[..., 3]) / 2) / h # y center - y[..., 2] = (x[..., 2] - x[..., 0]) / w # width - y[..., 3] = (x[..., 3] - x[..., 1]) / h # height - return y - - -def xywh2ltwh(x): - """ - Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates. - - Args: - x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format. - - Returns: - (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format. - """ - y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x - y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y - return y - - -def xyxy2ltwh(x): - """ - Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format. - - Args: - x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format. - - Returns: - (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format. - """ - y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[..., 2] = x[..., 2] - x[..., 0] # width - y[..., 3] = x[..., 3] - x[..., 1] # height - return y - - -def ltwh2xywh(x): - """ - Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center. - - Args: - x (torch.Tensor): Input bounding box coordinates. - - Returns: - (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format. 
- """ - y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x - y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y - return y - - -def xyxyxyxy2xywhr(x): - """ - Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format. - - Args: - x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format. - - Returns: - (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5). - Rotation values are in radians from 0 to pi/2. - """ - is_torch = isinstance(x, torch.Tensor) - points = x.cpu().numpy() if is_torch else x - points = points.reshape(len(x), -1, 2) - rboxes = [] - for pts in points: - # NOTE: Use cv2.minAreaRect to get accurate xywhr, - # especially some objects are cut off by augmentations in dataloader. - (cx, cy), (w, h), angle = cv2.minAreaRect(pts) - rboxes.append([cx, cy, w, h, angle / 180 * np.pi]) - return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes) - - -def xywhr2xyxyxyxy(x): - """ - Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format. - - Args: - x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5). - Rotation values should be in radians from 0 to pi/2. - - Returns: - (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2). 
- """ - cos, sin, cat, stack = ( - (torch.cos, torch.sin, torch.cat, torch.stack) - if isinstance(x, torch.Tensor) - else (np.cos, np.sin, np.concatenate, np.stack) - ) - - ctr = x[..., :2] - w, h, angle = (x[..., i : i + 1] for i in range(2, 5)) - cos_value, sin_value = cos(angle), sin(angle) - vec1 = [w / 2 * cos_value, w / 2 * sin_value] - vec2 = [-h / 2 * sin_value, h / 2 * cos_value] - vec1 = cat(vec1, -1) - vec2 = cat(vec2, -1) - pt1 = ctr + vec1 + vec2 - pt2 = ctr + vec1 - vec2 - pt3 = ctr - vec1 - vec2 - pt4 = ctr - vec1 + vec2 - return stack([pt1, pt2, pt3, pt4], -2) - - -def ltwh2xyxy(x): - """ - Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right. - - Args: - x (np.ndarray | torch.Tensor): Input bounding box coordinates. - - Returns: - (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format. - """ - y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) - y[..., 2] = x[..., 2] + x[..., 0] # width - y[..., 3] = x[..., 3] + x[..., 1] # height - return y - - -def segments2boxes(segments): - """ - Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh). - - Args: - segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates. - - Returns: - (np.ndarray): Bounding box coordinates in xywh format. - """ - boxes = [] - for s in segments: - x, y = s.T # segment xy - boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy - return xyxy2xywh(np.array(boxes)) # cls, xywh - - -def resample_segments(segments, n: int = 1000): - """ - Resample segments to n points each using linear interpolation. - - Args: - segments (list): List of (N, 2) arrays where N is the number of points in each segment. - n (int): Number of points to resample each segment to. - - Returns: - (list): Resampled segments with n points each. 
- """ - for i, s in enumerate(segments): - if len(s) == n: - continue - s = np.concatenate((s, s[0:1, :]), axis=0) - x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n) - xp = np.arange(len(s)) - x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x - segments[i] = ( - np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T - ) # segment xy - return segments - - -def crop_mask(masks, boxes): - """ - Crop masks to bounding box regions. - - Args: - masks (torch.Tensor): Masks with shape (N, H, W). - boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form. - - Returns: - (torch.Tensor): Cropped masks. - """ - _, h, w = masks.shape - x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1) - r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w) - c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1) - - return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) - - -def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False): - """ - Apply masks to bounding boxes using mask head output. - - Args: - protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w). - masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS. - bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS. - shape (tuple): Input image size as (height, width). - upsample (bool): Whether to upsample masks to original image size. - - Returns: - (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w - are the height and width of the input image. The mask is applied to the bounding boxes. 
- """ - c, mh, mw = protos.shape # CHW - ih, iw = shape - masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW - width_ratio = mw / iw - height_ratio = mh / ih - - downsampled_bboxes = bboxes.clone() - downsampled_bboxes[:, 0] *= width_ratio - downsampled_bboxes[:, 2] *= width_ratio - downsampled_bboxes[:, 3] *= height_ratio - downsampled_bboxes[:, 1] *= height_ratio - - masks = crop_mask(masks, downsampled_bboxes) # CHW - if upsample: - masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW - return masks.gt_(0.0) - - -def process_mask_native(protos, masks_in, bboxes, shape): - """ - Apply masks to bounding boxes using mask head output with native upsampling. - - Args: - protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w). - masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS. - bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS. - shape (tuple): Input image size as (height, width). - - Returns: - (torch.Tensor): Binary mask tensor with shape (H, W, N). - """ - c, mh, mw = protos.shape # CHW - masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) - masks = scale_masks(masks[None], shape)[0] # CHW - masks = crop_mask(masks, bboxes) # CHW - return masks.gt_(0.0) - - -def scale_masks(masks, shape, padding: bool = True): - """ - Rescale segment masks to target shape. - - Args: - masks (torch.Tensor): Masks with shape (N, C, H, W). - shape (tuple): Target height and width as (height, width). - padding (bool): Whether masks are based on YOLO-style augmented images with padding. - - Returns: - (torch.Tensor): Rescaled masks. 
- """ - mh, mw = masks.shape[2:] - gain = min(mh / shape[0], mw / shape[1]) # gain = old / new - pad = [mw - shape[1] * gain, mh - shape[0] * gain] # wh padding - if padding: - pad[0] /= 2 - pad[1] /= 2 - top, left = (int(round(pad[1] - 0.1)), int(round(pad[0] - 0.1))) if padding else (0, 0) # y, x - bottom, right = ( - mh - int(round(pad[1] + 0.1)), - mw - int(round(pad[0] + 0.1)), - ) - masks = masks[..., top:bottom, left:right] - - masks = F.interpolate(masks, shape, mode="bilinear", align_corners=False) # NCHW - return masks - - -def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True): - """ - Rescale segment coordinates from img1_shape to img0_shape. - - Args: - img1_shape (tuple): Shape of the source image. - coords (torch.Tensor): Coordinates to scale with shape (N, 2). - img0_shape (tuple): Shape of the target image. - ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)). - normalize (bool): Whether to normalize coordinates to range [0, 1]. - padding (bool): Whether coordinates are based on YOLO-style augmented images with padding. - - Returns: - (torch.Tensor): Scaled coordinates. - """ - if ratio_pad is None: # calculate from img0_shape - gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new - pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding - else: - gain = ratio_pad[0][0] - pad = ratio_pad[1] - - if padding: - coords[..., 0] -= pad[0] # x padding - coords[..., 1] -= pad[1] # y padding - coords[..., 0] /= gain - coords[..., 1] /= gain - coords = clip_coords(coords, img0_shape) - if normalize: - coords[..., 0] /= img0_shape[1] # width - coords[..., 1] /= img0_shape[0] # height - return coords - - -def regularize_rboxes(rboxes): - """ - Regularize rotated bounding boxes to range [0, pi/2]. 
- - Args: - rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format. - - Returns: - (torch.Tensor): Regularized rotated boxes. - """ - x, y, w, h, t = rboxes.unbind(dim=-1) - # Swap edge if t >= pi/2 while not being symmetrically opposite - swap = t % math.pi >= math.pi / 2 - w_ = torch.where(swap, h, w) - h_ = torch.where(swap, w, h) - t = t % (math.pi / 2) - return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes - - -def masks2segments(masks, strategy: str = "all"): - """ - Convert masks to segments using contour detection. - - Args: - masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160). - strategy (str): Segmentation strategy, either 'all' or 'largest'. - - Returns: - (list): List of segment masks as float32 arrays. - """ - from ultralytics.data.converter import merge_multi_segment - - segments = [] - for x in masks.int().cpu().numpy().astype("uint8"): - c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] - if c: - if strategy == "all": # merge and concatenate all segments - c = ( - np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c])) - if len(c) > 1 - else c[0].reshape(-1, 2) - ) - elif strategy == "largest": # select largest segment - c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2) - else: - c = np.zeros((0, 2)) # no segments found - segments.append(c.astype("float32")) - return segments - - -def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray: - """ - Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout. - - Args: - batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32. - - Returns: - (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8. 
- """ - return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy() - - -def clean_str(s): - """ - Clean a string by replacing special characters with '_' character. - - Args: - s (str): A string needing special characters replaced. - - Returns: - (str): A string with special characters replaced by an underscore _. - """ - return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s) - - -def empty_like(x): - """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype.""" - return ( - torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32) - ) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/patches.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/patches.py deleted file mode 100644 index 25775f8..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/patches.py +++ /dev/null @@ -1,187 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -"""Monkey patches to update/extend functionality of existing functions.""" - -import time -from contextlib import contextmanager -from copy import copy -from pathlib import Path -from typing import Any, Dict, List, Optional - -import cv2 -import numpy as np -import torch - -# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------ -_imshow = cv2.imshow # copy to avoid recursion errors - - -def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> Optional[np.ndarray]: - """ - Read an image from a file with multilanguage filename support. - - Args: - filename (str): Path to the file to read. - flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read. - - Returns: - (np.ndarray | None): The read image array, or None if reading fails. 
- - Examples: - >>> img = imread("path/to/image.jpg") - >>> img = imread("path/to/image.jpg", cv2.IMREAD_GRAYSCALE) - """ - file_bytes = np.fromfile(filename, np.uint8) - if filename.endswith((".tiff", ".tif")): - success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED) - if success: - # Handle RGB images in tif/tiff format - return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2) - return None - else: - im = cv2.imdecode(file_bytes, flags) - return im[..., None] if im is not None and im.ndim == 2 else im # Always ensure 3 dimensions - - -def imwrite(filename: str, img: np.ndarray, params: Optional[List[int]] = None) -> bool: - """ - Write an image to a file with multilanguage filename support. - - Args: - filename (str): Path to the file to write. - img (np.ndarray): Image to write. - params (List[int], optional): Additional parameters for image encoding. - - Returns: - (bool): True if the file was written successfully, False otherwise. - - Examples: - >>> import numpy as np - >>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image - >>> success = imwrite("output.jpg", img) # Write image to file - >>> print(success) - True - """ - try: - cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename) - return True - except Exception: - return False - - -def imshow(winname: str, mat: np.ndarray) -> None: - """ - Display an image in the specified window with multilanguage window name support. - - This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles - multilanguage window names by encoding them properly for OpenCV compatibility. - - Args: - winname (str): Name of the window where the image will be displayed. If a window with this name already - exists, the image will be displayed in that window. - mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image. 
- - Examples: - >>> import numpy as np - >>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image - >>> img[:100, :100] = [255, 0, 0] # Add a blue square - >>> imshow("Example Window", img) # Display the image - """ - _imshow(winname.encode("unicode_escape").decode(), mat) - - -# PyTorch functions ---------------------------------------------------------------------------------------------------- -_torch_save = torch.save - - -def torch_load(*args, **kwargs): - """ - Load a PyTorch model with updated arguments to avoid warnings. - - This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings. - - Args: - *args (Any): Variable length argument list to pass to torch.load. - **kwargs (Any): Arbitrary keyword arguments to pass to torch.load. - - Returns: - (Any): The loaded PyTorch object. - - Notes: - For PyTorch versions 2.0 and above, this function automatically sets 'weights_only=False' - if the argument is not provided, to avoid deprecation warnings. - """ - from ultralytics.utils.torch_utils import TORCH_1_13 - - if TORCH_1_13 and "weights_only" not in kwargs: - kwargs["weights_only"] = False - - return torch.load(*args, **kwargs) - - -def torch_save(*args, **kwargs): - """ - Save PyTorch objects with retry mechanism for robustness. - - This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur - due to device flushing delays or antivirus scanning. - - Args: - *args (Any): Positional arguments to pass to torch.save. - **kwargs (Any): Keyword arguments to pass to torch.save. 
- - Examples: - >>> model = torch.nn.Linear(10, 1) - >>> torch_save(model.state_dict(), "model.pt") - """ - for i in range(4): # 3 retries - try: - return _torch_save(*args, **kwargs) - except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan - if i == 3: - raise e - time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s - - -@contextmanager -def arange_patch(args): - """ - Workaround for ONNX torch.arange incompatibility with FP16. - - https://github.com/pytorch/pytorch/issues/148041. - """ - if args.dynamic and args.half and args.format == "onnx": - func = torch.arange - - def arange(*args, dtype=None, **kwargs): - """Return a 1-D tensor of size with values from the interval and common difference.""" - return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype - - torch.arange = arange # patch - yield - torch.arange = func # unpatch - else: - yield - - -@contextmanager -def override_configs(args, overrides: Optional[Dict[str, Any]] = None): - """ - Context manager to temporarily override configurations in args. - - Args: - args (IterableSimpleNamespace): Original configuration arguments. - overrides (Dict[str, Any]): Dictionary of overrides to apply. - - Yields: - (IterableSimpleNamespace): Configuration arguments with overrides applied. 
- """ - if overrides: - original_args = copy(args) - for key, value in overrides.items(): - setattr(args, key, value) - try: - yield args - finally: - args.__dict__.update(original_args.__dict__) - else: - yield args diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/plotting.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/plotting.py deleted file mode 100644 index df96f4e..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/plotting.py +++ /dev/null @@ -1,1037 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import math -import warnings -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Union - -import cv2 -import numpy as np -import torch -from PIL import Image, ImageDraw, ImageFont -from PIL import __version__ as pil_version - -from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded -from ultralytics.utils.checks import check_font, check_version, is_ascii -from ultralytics.utils.files import increment_path - - -class Colors: - """ - Ultralytics color palette for visualization and plotting. - - This class provides methods to work with the Ultralytics color palette, including converting hex color codes to - RGB values and accessing predefined color schemes for object detection and pose estimation. - - Attributes: - palette (List[tuple]): List of RGB color tuples for general use. - n (int): The number of colors in the palette. - pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8. 
- - Examples: - >>> from ultralytics.utils.plotting import Colors - >>> colors = Colors() - >>> colors(5, True) # Returns BGR format: (221, 111, 255) - >>> colors(5, False) # Returns RGB format: (255, 111, 221) - - ## Ultralytics Color Palette - - | Index | Color | HEX | RGB | - |-------|-------------------------------------------------------------------|-----------|-------------------| - | 0 | | `#042aff` | (4, 42, 255) | - | 1 | | `#0bdbeb` | (11, 219, 235) | - | 2 | | `#f3f3f3` | (243, 243, 243) | - | 3 | | `#00dfb7` | (0, 223, 183) | - | 4 | | `#111f68` | (17, 31, 104) | - | 5 | | `#ff6fdd` | (255, 111, 221) | - | 6 | | `#ff444f` | (255, 68, 79) | - | 7 | | `#cced00` | (204, 237, 0) | - | 8 | | `#00f344` | (0, 243, 68) | - | 9 | | `#bd00ff` | (189, 0, 255) | - | 10 | | `#00b4ff` | (0, 180, 255) | - | 11 | | `#dd00ba` | (221, 0, 186) | - | 12 | | `#00ffff` | (0, 255, 255) | - | 13 | | `#26c000` | (38, 192, 0) | - | 14 | | `#01ffb3` | (1, 255, 179) | - | 15 | | `#7d24ff` | (125, 36, 255) | - | 16 | | `#7b0068` | (123, 0, 104) | - | 17 | | `#ff1b6c` | (255, 27, 108) | - | 18 | | `#fc6d2f` | (252, 109, 47) | - | 19 | | `#a2ff0b` | (162, 255, 11) | - - ## Pose Color Palette - - | Index | Color | HEX | RGB | - |-------|-------------------------------------------------------------------|-----------|-------------------| - | 0 | | `#ff8000` | (255, 128, 0) | - | 1 | | `#ff9933` | (255, 153, 51) | - | 2 | | `#ffb266` | (255, 178, 102) | - | 3 | | `#e6e600` | (230, 230, 0) | - | 4 | | `#ff99ff` | (255, 153, 255) | - | 5 | | `#99ccff` | (153, 204, 255) | - | 6 | | `#ff66ff` | (255, 102, 255) | - | 7 | | `#ff33ff` | (255, 51, 255) | - | 8 | | `#66b2ff` | (102, 178, 255) | - | 9 | | `#3399ff` | (51, 153, 255) | - | 10 | | `#ff9999` | (255, 153, 153) | - | 11 | | `#ff6666` | (255, 102, 102) | - | 12 | | `#ff3333` | (255, 51, 51) | - | 13 | | `#99ff99` | (153, 255, 153) | - | 14 | | `#66ff66` | (102, 255, 102) | - | 15 | | `#33ff33` | (51, 255, 51) | - | 16 | | `#00ff00` | (0, 
255, 0) | - | 17 | | `#0000ff` | (0, 0, 255) | - | 18 | | `#ff0000` | (255, 0, 0) | - | 19 | | `#ffffff` | (255, 255, 255) | - - !!! note "Ultralytics Brand Colors" - - For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand). - Please use the official Ultralytics colors for all marketing materials. - """ - - def __init__(self): - """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values().""" - hexs = ( - "042AFF", - "0BDBEB", - "F3F3F3", - "00DFB7", - "111F68", - "FF6FDD", - "FF444F", - "CCED00", - "00F344", - "BD00FF", - "00B4FF", - "DD00BA", - "00FFFF", - "26C000", - "01FFB3", - "7D24FF", - "7B0068", - "FF1B6C", - "FC6D2F", - "A2FF0B", - ) - self.palette = [self.hex2rgb(f"#{c}") for c in hexs] - self.n = len(self.palette) - self.pose_palette = np.array( - [ - [255, 128, 0], - [255, 153, 51], - [255, 178, 102], - [230, 230, 0], - [255, 153, 255], - [153, 204, 255], - [255, 102, 255], - [255, 51, 255], - [102, 178, 255], - [51, 153, 255], - [255, 153, 153], - [255, 102, 102], - [255, 51, 51], - [153, 255, 153], - [102, 255, 102], - [51, 255, 51], - [0, 255, 0], - [0, 0, 255], - [255, 0, 0], - [255, 255, 255], - ], - dtype=np.uint8, - ) - - def __call__(self, i: int, bgr: bool = False) -> tuple: - """ - Convert hex color codes to RGB values. - - Args: - i (int): Color index. - bgr (bool, optional): Whether to return BGR format instead of RGB. - - Returns: - (tuple): RGB or BGR color tuple. - """ - c = self.palette[int(i) % self.n] - return (c[2], c[1], c[0]) if bgr else c - - @staticmethod - def hex2rgb(h: str) -> tuple: - """Convert hex color codes to RGB values (i.e. default PIL order).""" - return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4)) - - -colors = Colors() # create instance for 'from utils.plots import colors' - - -class Annotator: - """ - Ultralytics Annotator for train/val mosaics and JPGs and predictions annotations. 
- - Attributes: - im (Image.Image | np.ndarray): The image to annotate. - pil (bool): Whether to use PIL or cv2 for drawing annotations. - font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations. - lw (float): Line width for drawing. - skeleton (List[List[int]]): Skeleton structure for keypoints. - limb_color (List[int]): Color palette for limbs. - kpt_color (List[int]): Color palette for keypoints. - dark_colors (set): Set of colors considered dark for text contrast. - light_colors (set): Set of colors considered light for text contrast. - - Examples: - >>> from ultralytics.utils.plotting import Annotator - >>> im0 = cv2.imread("test.png") - >>> annotator = Annotator(im0, line_width=10) - >>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0)) - """ - - def __init__( - self, - im, - line_width: Optional[int] = None, - font_size: Optional[int] = None, - font: str = "Arial.ttf", - pil: bool = False, - example: str = "abc", - ): - """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs.""" - non_ascii = not is_ascii(example) # non-latin labels, i.e. 
asian, arabic, cyrillic - input_is_pil = isinstance(im, Image.Image) - self.pil = pil or non_ascii or input_is_pil - self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2) - if not input_is_pil: - if im.shape[2] == 1: # handle grayscale - im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) - elif im.shape[2] > 3: # multispectral - im = np.ascontiguousarray(im[..., :3]) - if self.pil: # use PIL - self.im = im if input_is_pil else Image.fromarray(im) - if self.im.mode not in {"RGB", "RGBA"}: # multispectral - self.im = self.im.convert("RGB") - self.draw = ImageDraw.Draw(self.im, "RGBA") - try: - font = check_font("Arial.Unicode.ttf" if non_ascii else font) - size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12) - self.font = ImageFont.truetype(str(font), size) - except Exception: - self.font = ImageFont.load_default() - # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbox(string) - if check_version(pil_version, "9.2.0"): - self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height - else: # use cv2 - assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images." 
- self.im = im if im.flags.writeable else im.copy() - self.tf = max(self.lw - 1, 1) # font thickness - self.sf = self.lw / 3 # font scale - # Pose - self.skeleton = [ - [16, 14], - [14, 12], - [17, 15], - [15, 13], - [12, 13], - [6, 12], - [7, 13], - [6, 7], - [6, 8], - [7, 9], - [8, 10], - [9, 11], - [2, 3], - [1, 2], - [1, 3], - [2, 4], - [3, 5], - [4, 6], - [5, 7], - ] - - self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]] - self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]] - self.dark_colors = { - (235, 219, 11), - (243, 243, 243), - (183, 223, 0), - (221, 111, 255), - (0, 237, 204), - (68, 243, 0), - (255, 255, 0), - (179, 255, 1), - (11, 255, 162), - } - self.light_colors = { - (255, 42, 4), - (79, 68, 255), - (255, 0, 189), - (255, 180, 0), - (186, 0, 221), - (0, 192, 38), - (255, 36, 125), - (104, 0, 123), - (108, 27, 255), - (47, 109, 252), - (104, 31, 17), - } - - def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple: - """ - Assign text color based on background color. - - Args: - color (tuple, optional): The background color of the rectangle for text (B, G, R). - txt_color (tuple, optional): The color of the text (R, G, B). - - Returns: - (tuple): Text color for label. - - Examples: - >>> from ultralytics.utils.plotting import Annotator - >>> im0 = cv2.imread("test.png") - >>> annotator = Annotator(im0, line_width=10) - >>> annotator.get_txt_color(color=(104, 31, 17)) # return (255, 255, 255) - """ - if color in self.dark_colors: - return 104, 31, 17 - elif color in self.light_colors: - return 255, 255, 255 - else: - return txt_color - - def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)): - """ - Draw a bounding box on an image with a given label. - - Args: - box (tuple): The bounding box coordinates (x1, y1, x2, y2). 
- label (str, optional): The text label to be displayed. - color (tuple, optional): The background color of the rectangle (B, G, R). - txt_color (tuple, optional): The color of the text (R, G, B). - - Examples: - >>> from ultralytics.utils.plotting import Annotator - >>> im0 = cv2.imread("test.png") - >>> annotator = Annotator(im0, line_width=10) - >>> annotator.box_label(box=[10, 20, 30, 40], label="person") - """ - txt_color = self.get_txt_color(color, txt_color) - if isinstance(box, torch.Tensor): - box = box.tolist() - - multi_points = isinstance(box[0], list) # multiple points with shape (n, 2) - p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1])) - if self.pil: - self.draw.polygon( - [tuple(b) for b in box], width=self.lw, outline=color - ) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color) - if label: - w, h = self.font.getsize(label) # text width, height - outside = p1[1] >= h # label fits outside box - if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image - p1 = self.im.size[0] - w, p1[1] - self.draw.rectangle( - (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1), - fill=color, - ) - # self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0 - self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font) - else: # cv2 - cv2.polylines( - self.im, [np.asarray(box, dtype=int)], True, color, self.lw - ) if multi_points else cv2.rectangle( - self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA - ) - if label: - w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height - h += 3 # add pixels to pad text - outside = p1[1] >= h # label fits outside box - if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image - p1 = self.im.shape[1] - 
w, p1[1] - p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h - cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled - cv2.putText( - self.im, - label, - (p1[0], p1[1] - 2 if outside else p1[1] + h - 1), - 0, - self.sf, - txt_color, - thickness=self.tf, - lineType=cv2.LINE_AA, - ) - - def masks(self, masks, colors, im_gpu, alpha: float = 0.5, retina_masks: bool = False): - """ - Plot masks on image. - - Args: - masks (torch.Tensor): Predicted masks on cuda, shape: [n, h, w] - colors (List[List[int]]): Colors for predicted masks, [[r, g, b] * n] - im_gpu (torch.Tensor): Image is in cuda, shape: [3, h, w], range: [0, 1] - alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque. - retina_masks (bool, optional): Whether to use high resolution masks or not. - """ - if self.pil: - # Convert to numpy first - self.im = np.asarray(self.im).copy() - if len(masks) == 0: - self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255 - if im_gpu.device != masks.device: - im_gpu = im_gpu.to(masks.device) - colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3) - colors = colors[:, None, None] # shape(n,1,1,3) - masks = masks.unsqueeze(3) # shape(n,h,w,1) - masks_color = masks * (colors * alpha) # shape(n,h,w,3) - - inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1) - mcs = masks_color.max(dim=0).values # shape(n,h,w,3) - - im_gpu = im_gpu.flip(dims=[0]) # flip channel - im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3) - im_gpu = im_gpu * inv_alpha_masks[-1] + mcs - im_mask = im_gpu * 255 - im_mask_np = im_mask.byte().cpu().numpy() - self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape) - if self.pil: - # Convert im back to PIL and update draw - self.fromarray(self.im) - - def kpts( - self, - kpts, - shape: tuple = (640, 640), - radius: Optional[int] = None, - kpt_line: bool = True, - conf_thres: float = 0.25, - kpt_color: Optional[tuple] = 
None, - ): - """ - Plot keypoints on the image. - - Args: - kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence). - shape (tuple, optional): Image shape (h, w). - radius (int, optional): Keypoint radius. - kpt_line (bool, optional): Draw lines between keypoints. - conf_thres (float, optional): Confidence threshold. - kpt_color (tuple, optional): Keypoint color (B, G, R). - - Note: - - `kpt_line=True` currently only supports human pose plotting. - - Modifies self.im in-place. - - If self.pil is True, converts image to numpy array and back to PIL. - """ - radius = radius if radius is not None else self.lw - if self.pil: - # Convert to numpy first - self.im = np.asarray(self.im).copy() - nkpt, ndim = kpts.shape - is_pose = nkpt == 17 and ndim in {2, 3} - kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting - for i, k in enumerate(kpts): - color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i)) - x_coord, y_coord = k[0], k[1] - if x_coord % shape[1] != 0 and y_coord % shape[0] != 0: - if len(k) == 3: - conf = k[2] - if conf < conf_thres: - continue - cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA) - - if kpt_line: - ndim = kpts.shape[-1] - for i, sk in enumerate(self.skeleton): - pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1])) - pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1])) - if ndim == 3: - conf1 = kpts[(sk[0] - 1), 2] - conf2 = kpts[(sk[1] - 1), 2] - if conf1 < conf_thres or conf2 < conf_thres: - continue - if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0: - continue - if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0: - continue - cv2.line( - self.im, - pos1, - pos2, - kpt_color or self.limb_color[i].tolist(), - thickness=int(np.ceil(self.lw / 2)), - lineType=cv2.LINE_AA, - ) - if self.pil: - # Convert im back to PIL and update draw - self.fromarray(self.im) - - 
def rectangle(self, xy, fill=None, outline=None, width: int = 1): - """Add rectangle to image (PIL-only).""" - self.draw.rectangle(xy, fill, outline, width) - - def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()): - """ - Add text to an image using PIL or cv2. - - Args: - xy (List[int]): Top-left coordinates for text placement. - text (str): Text to be drawn. - txt_color (tuple, optional): Text color (R, G, B). - anchor (str, optional): Text anchor position ('top' or 'bottom'). - box_color (tuple, optional): Box color (R, G, B, A) with optional alpha. - """ - if self.pil: - w, h = self.font.getsize(text) - if anchor == "bottom": # start y from font bottom - xy[1] += 1 - h - for line in text.split("\n"): - if box_color: - # Draw rectangle for each line - w, h = self.font.getsize(line) - self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=box_color) - self.draw.text(xy, line, fill=txt_color, font=self.font) - xy[1] += h - else: - if box_color: - w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0] - h += 3 # add pixels to pad text - outside = xy[1] >= h # label fits outside box - p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h - cv2.rectangle(self.im, xy, p2, box_color, -1, cv2.LINE_AA) # filled - cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA) - - def fromarray(self, im): - """Update self.im from a numpy array.""" - self.im = im if isinstance(im, Image.Image) else Image.fromarray(im) - self.draw = ImageDraw.Draw(self.im) - - def result(self): - """Return annotated image as array.""" - return np.asarray(self.im) - - def show(self, title: Optional[str] = None): - """Show the annotated image.""" - im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert numpy array to PIL Image with RGB to BGR - if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments - try: - display(im) # noqa - 
display() function only available in ipython environments - except ImportError as e: - LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}") - else: - im.show(title=title) - - def save(self, filename: str = "image.jpg"): - """Save the annotated image to 'filename'.""" - cv2.imwrite(filename, np.asarray(self.im)) - - @staticmethod - def get_bbox_dimension(bbox: Optional[tuple] = None): - """ - Calculate the dimensions and area of a bounding box. - - Args: - bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max). - - Returns: - width (float): Width of the bounding box. - height (float): Height of the bounding box. - area (float): Area enclosed by the bounding box. - - Examples: - >>> from ultralytics.utils.plotting import Annotator - >>> im0 = cv2.imread("test.png") - >>> annotator = Annotator(im0, line_width=10) - >>> annotator.get_bbox_dimension(bbox=[10, 20, 30, 40]) - """ - x_min, y_min, x_max, y_max = bbox - width = x_max - x_min - height = y_max - y_min - return width, height, width * height - - -@TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395 -@plt_settings() -def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None): - """ - Plot training labels including class histograms and box statistics. - - Args: - boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height]. - cls (np.ndarray): Class indices. - names (dict, optional): Dictionary mapping class indices to class names. - save_dir (Path, optional): Directory to save the plot. - on_plot (Callable, optional): Function to call after plot is saved. 
- """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - import pandas - from matplotlib.colors import LinearSegmentedColormap - - # Filter matplotlib>=3.7.2 warning - warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight") - warnings.filterwarnings("ignore", category=FutureWarning) - - # Plot dataset labels - LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ") - nc = int(cls.max() + 1) # number of classes - boxes = boxes[:1000000] # limit to 1M boxes - x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"]) - - try: # Seaborn correlogram - import seaborn - - seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) - plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200) - plt.close() - except ImportError: - pass # Skip if seaborn is not installed - - # Matplotlib labels - subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"]) - ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() - y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) - for i in range(nc): - y[2].patches[i].set_color([x / 255 for x in colors(i)]) - ax[0].set_ylabel("instances") - if 0 < len(names) < 30: - ax[0].set_xticks(range(len(names))) - ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10) - else: - ax[0].set_xlabel("classes") - boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000 - img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255) - for cls, box in zip(cls[:500], boxes[:500]): - ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot - ax[1].imshow(img) - ax[1].axis("off") - - ax[2].hist2d(x["x"], x["y"], bins=50, cmap=subplot_3_4_color) - ax[2].set_xlabel("x") - ax[2].set_ylabel("y") - ax[3].hist2d(x["width"], x["height"], bins=50, cmap=subplot_3_4_color) - ax[3].set_xlabel("width") - 
ax[3].set_ylabel("height") - for a in {0, 1, 2, 3}: - for s in {"top", "right", "left", "bottom"}: - ax[a].spines[s].set_visible(False) - - fname = save_dir / "labels.jpg" - plt.savefig(fname, dpi=200) - plt.close() - if on_plot: - on_plot(fname) - - -def save_one_box( - xyxy, - im, - file: Path = Path("im.jpg"), - gain: float = 1.02, - pad: int = 10, - square: bool = False, - BGR: bool = False, - save: bool = True, -): - """ - Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop. - - This function takes a bounding box and an image, and then saves a cropped portion of the image according - to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding - adjustments to the bounding box. - - Args: - xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format. - im (np.ndarray): The input image. - file (Path, optional): The path where the cropped image will be saved. - gain (float, optional): A multiplicative factor to increase the size of the bounding box. - pad (int, optional): The number of pixels to add to the width and height of the bounding box. - square (bool, optional): If True, the bounding box will be transformed into a square. - BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB. - save (bool, optional): If True, the cropped image will be saved to disk. - - Returns: - (np.ndarray): The cropped image. 
- - Examples: - >>> from ultralytics.utils.plotting import save_one_box - >>> xyxy = [50, 50, 150, 150] - >>> im = cv2.imread("image.jpg") - >>> cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True) - """ - if not isinstance(xyxy, torch.Tensor): # may be list - xyxy = torch.stack(xyxy) - b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes - if square: - b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square - b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad - xyxy = ops.xywh2xyxy(b).long() - xyxy = ops.clip_boxes(xyxy, im.shape) - grayscale = im.shape[2] == 1 # grayscale image - crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)] - if save: - file.parent.mkdir(parents=True, exist_ok=True) # make directory - f = str(increment_path(file).with_suffix(".jpg")) - # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue - crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop - Image.fromarray(crop).save(f, quality=95, subsampling=0) # save RGB - return crop - - -@threaded -def plot_images( - labels: Dict[str, Any], - images: Union[torch.Tensor, np.ndarray] = np.zeros((0, 3, 640, 640), dtype=np.float32), - paths: Optional[List[str]] = None, - fname: str = "images.jpg", - names: Optional[Dict[int, str]] = None, - on_plot: Optional[Callable] = None, - max_size: int = 1920, - max_subplots: int = 16, - save: bool = True, - conf_thres: float = 0.25, -) -> Optional[np.ndarray]: - """ - Plot image grid with labels, bounding boxes, masks, and keypoints. - - Args: - labels (Dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'. - images (torch.Tensor | np.ndarray]): Batch of images to plot. Shape: (batch_size, channels, height, width). - paths (Optional[List[str]]): List of file paths for each image in the batch. 
- fname (str): Output filename for the plotted image grid. - names (Optional[Dict[int, str]]): Dictionary mapping class indices to class names. - on_plot (Optional[Callable]): Optional callback function to be called after saving the plot. - max_size (int): Maximum size of the output image grid. - max_subplots (int): Maximum number of subplots in the image grid. - save (bool): Whether to save the plotted image grid to a file. - conf_thres (float): Confidence threshold for displaying detections. - - Returns: - (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise. - - Note: - This function supports both tensor and numpy array inputs. It will automatically - convert tensor inputs to numpy arrays for processing. - """ - for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}: - if k not in labels: - continue - if k == "cls" and labels[k].ndim == 2: - labels[k] = labels[k].squeeze(1) # squeeze if shape is (n, 1) - if isinstance(labels[k], torch.Tensor): - labels[k] = labels[k].cpu().numpy() - - cls = labels.get("cls", np.zeros(0, dtype=np.int64)) - batch_idx = labels.get("batch_idx", np.zeros(cls.shape, dtype=np.int64)) - bboxes = labels.get("bboxes", np.zeros(0, dtype=np.float32)) - confs = labels.get("conf", None) - masks = labels.get("masks", np.zeros(0, dtype=np.uint8)) - kpts = labels.get("keypoints", np.zeros(0, dtype=np.float32)) - images = labels.get("img", images) # default to input images - - if len(images) and isinstance(images, torch.Tensor): - images = images.cpu().float().numpy() - if images.shape[1] > 3: - images = images[:, :3] # crop multispectral images to first 3 channels - - bs, _, h, w = images.shape # batch size, _, height, width - bs = min(bs, max_subplots) # limit plot images - ns = np.ceil(bs**0.5) # number of subplots (square) - if np.max(images[0]) <= 1: - images *= 255 # de-normalise (optional) - - # Build Image - mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init - 
for i in range(bs): - x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin - mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0) - - # Resize (optional) - scale = max_size / ns / max(h, w) - if scale < 1: - h = math.ceil(scale * h) - w = math.ceil(scale * w) - mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h))) - - # Annotate - fs = int((h + w) * ns * 0.01) # font size - fs = max(fs, 18) # ensure that the font size is large enough to be easily readable. - annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=str(names)) - for i in range(bs): - x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin - annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders - if paths: - annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames - if len(cls) > 0: - idx = batch_idx == i - classes = cls[idx].astype("int") - labels = confs is None - - if len(bboxes): - boxes = bboxes[idx] - conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred) - if len(boxes): - if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1 - boxes[..., [0, 2]] *= w # scale to pixels - boxes[..., [1, 3]] *= h - elif scale < 1: # absolute coords need scale if image scales - boxes[..., :4] *= scale - boxes[..., 0] += x - boxes[..., 1] += y - is_obb = boxes.shape[-1] == 5 # xywhr - # TODO: this transformation might be unnecessary - boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes) - for j, box in enumerate(boxes.astype(np.int64).tolist()): - c = classes[j] - color = colors(c) - c = names.get(c, c) if names else c - if labels or conf[j] > conf_thres: - label = f"{c}" if labels else f"{c} {conf[j]:.1f}" - annotator.box_label(box, label, color=color) - - elif len(classes): - for c in classes: - color = colors(c) - c = names.get(c, c) if names else c - annotator.text([x, y], f"{c}", txt_color=color, 
box_color=(64, 64, 64, 128)) - - # Plot keypoints - if len(kpts): - kpts_ = kpts[idx].copy() - if len(kpts_): - if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01 - kpts_[..., 0] *= w # scale to pixels - kpts_[..., 1] *= h - elif scale < 1: # absolute coords need scale if image scales - kpts_ *= scale - kpts_[..., 0] += x - kpts_[..., 1] += y - for j in range(len(kpts_)): - if labels or conf[j] > conf_thres: - annotator.kpts(kpts_[j], conf_thres=conf_thres) - - # Plot masks - if len(masks): - if idx.shape[0] == masks.shape[0]: # overlap_mask=False - image_masks = masks[idx] - else: # overlap_mask=True - image_masks = masks[[i]] # (1, 640, 640) - nl = idx.sum() - index = np.arange(nl).reshape((nl, 1, 1)) + 1 - image_masks = np.repeat(image_masks, nl, axis=0) - image_masks = np.where(image_masks == index, 1.0, 0.0) - - im = np.asarray(annotator.im).copy() - for j in range(len(image_masks)): - if labels or conf[j] > conf_thres: - color = colors(classes[j]) - mh, mw = image_masks[j].shape - if mh != h or mw != w: - mask = image_masks[j].astype(np.uint8) - mask = cv2.resize(mask, (w, h)) - mask = mask.astype(bool) - else: - mask = image_masks[j].astype(bool) - try: - im[y : y + h, x : x + w, :][mask] = ( - im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6 - ) - except Exception: - pass - annotator.fromarray(im) - if not save: - return np.asarray(annotator.im) - annotator.im.save(fname) # save - if on_plot: - on_plot(fname) - - -@plt_settings() -def plot_results( - file: str = "path/to/results.csv", - dir: str = "", - segment: bool = False, - pose: bool = False, - classify: bool = False, - on_plot: Optional[Callable] = None, -): - """ - Plot training results from a results CSV file. The function supports various types of data including segmentation, - pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located. 
- - Args: - file (str, optional): Path to the CSV file containing the training results. - dir (str, optional): Directory where the CSV file is located if 'file' is not provided. - segment (bool, optional): Flag to indicate if the data is for segmentation. - pose (bool, optional): Flag to indicate if the data is for pose estimation. - classify (bool, optional): Flag to indicate if the data is for classification. - on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument. - - Examples: - >>> from ultralytics.utils.plotting import plot_results - >>> plot_results("path/to/results.csv", segment=True) - """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - import pandas as pd - from scipy.ndimage import gaussian_filter1d - - save_dir = Path(file).parent if file else Path(dir) - if classify: - fig, ax = plt.subplots(2, 2, figsize=(6, 6), tight_layout=True) - index = [2, 5, 3, 4] - elif segment: - fig, ax = plt.subplots(2, 8, figsize=(18, 6), tight_layout=True) - index = [2, 3, 4, 5, 6, 7, 10, 11, 14, 15, 16, 17, 8, 9, 12, 13] - elif pose: - fig, ax = plt.subplots(2, 9, figsize=(21, 6), tight_layout=True) - index = [2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 17, 18, 19, 9, 10, 13, 14] - else: - fig, ax = plt.subplots(2, 5, figsize=(12, 6), tight_layout=True) - index = [2, 3, 4, 5, 6, 9, 10, 11, 7, 8] - ax = ax.ravel() - files = list(save_dir.glob("results*.csv")) - assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot." 
- for f in files: - try: - data = pd.read_csv(f) - s = [x.strip() for x in data.columns] - x = data.values[:, 0] - for i, j in enumerate(index): - y = data.values[:, j].astype("float") - # y[y == 0] = np.nan # don't show zero values - ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results - ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line - ax[i].set_title(s[j], fontsize=12) - # if j in {8, 9, 10}: # share train and val loss y axes - # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5]) - except Exception as e: - LOGGER.error(f"Plotting error for {f}: {e}") - ax[1].legend() - fname = save_dir / "results.png" - fig.savefig(fname, dpi=200) - plt.close() - if on_plot: - on_plot(fname) - - -def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"): - """ - Plot a scatter plot with points colored based on a 2D histogram. - - Args: - v (array-like): Values for the x-axis. - f (array-like): Values for the y-axis. - bins (int, optional): Number of bins for the histogram. - cmap (str, optional): Colormap for the scatter plot. - alpha (float, optional): Alpha for the scatter plot. - edgecolors (str, optional): Edge colors for the scatter plot. 
- - Examples: - >>> v = np.random.rand(100) - >>> f = np.random.rand(100) - >>> plt_color_scatter(v, f) - """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - - # Calculate 2D histogram and corresponding colors - hist, xedges, yedges = np.histogram2d(v, f, bins=bins) - colors = [ - hist[ - min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1), - min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1), - ] - for i in range(len(v)) - ] - - # Scatter plot - plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors) - - -def plot_tune_results(csv_file: str = "tune_results.csv"): - """ - Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key - in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots. - - Args: - csv_file (str, optional): Path to the CSV file containing the tuning results. - - Examples: - >>> plot_tune_results("path/to/tune_results.csv") - """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - import pandas as pd - from scipy.ndimage import gaussian_filter1d - - def _save_one_file(file): - """Save one matplotlib plot to 'file'.""" - plt.savefig(file, dpi=200) - plt.close() - LOGGER.info(f"Saved {file}") - - # Scatter plots for each hyperparameter - csv_file = Path(csv_file) - data = pd.read_csv(csv_file) - num_metrics_columns = 1 - keys = [x.strip() for x in data.columns][num_metrics_columns:] - x = data.values - fitness = x[:, 0] # fitness - j = np.argmax(fitness) # max fitness index - n = math.ceil(len(keys) ** 0.5) # columns and rows in plot - plt.figure(figsize=(10, 10), tight_layout=True) - for i, k in enumerate(keys): - v = x[:, i + num_metrics_columns] - mu = v[j] # best single result - plt.subplot(n, n, i + 1) - plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none") - plt.plot(mu, fitness.max(), "k+", markersize=15) - 
plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # limit to 40 characters - plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8 - if i % n != 0: - plt.yticks([]) - _save_one_file(csv_file.with_name("tune_scatter_plots.png")) - - # Fitness vs iteration - x = range(1, len(fitness) + 1) - plt.figure(figsize=(10, 6), tight_layout=True) - plt.plot(x, fitness, marker="o", linestyle="none", label="fitness") - plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line - plt.title("Fitness vs Iteration") - plt.xlabel("Iteration") - plt.ylabel("Fitness") - plt.grid(True) - plt.legend() - _save_one_file(csv_file.with_name("tune_fitness.png")) - - -def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")): - """ - Visualize feature maps of a given model module during inference. - - Args: - x (torch.Tensor): Features to be visualized. - module_type (str): Module type. - stage (int): Module stage within the model. - n (int, optional): Maximum number of feature maps to plot. - save_dir (Path, optional): Directory to save results. - """ - import matplotlib.pyplot as plt # scope for faster 'import ultralytics' - - for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads - if m in module_type: - return - if isinstance(x, torch.Tensor): - _, channels, height, width = x.shape # batch, channels, height, width - if height > 1 and width > 1: - f = save_dir / f"stage{stage}_{module_type.rsplit('.', 1)[-1]}_features.png" # filename - - blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels - n = min(n, channels) # number of plots - _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols - ax = ax.ravel() - plt.subplots_adjust(wspace=0.05, hspace=0.05) - for i in range(n): - ax[i].imshow(blocks[i].squeeze()) # cmap='gray' - ax[i].axis("off") - - LOGGER.info(f"Saving {f}... 
({n}/{channels})") - plt.savefig(f, dpi=300, bbox_inches="tight") - plt.close() - np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/tal.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/tal.py deleted file mode 100644 index 3a2091f..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/tal.py +++ /dev/null @@ -1,419 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import torch -import torch.nn as nn - -from . import LOGGER -from .checks import check_version -from .metrics import bbox_iou, probiou -from .ops import xywhr2xyxyxyxy - -TORCH_1_10 = check_version(torch.__version__, "1.10.0") - - -class TaskAlignedAssigner(nn.Module): - """ - A task-aligned assigner for object detection. - - This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both - classification and localization information. - - Attributes: - topk (int): The number of top candidates to consider. - num_classes (int): The number of object classes. - alpha (float): The alpha parameter for the classification component of the task-aligned metric. - beta (float): The beta parameter for the localization component of the task-aligned metric. - eps (float): A small value to prevent division by zero. - """ - - def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9): - """ - Initialize a TaskAlignedAssigner object with customizable hyperparameters. - - Args: - topk (int, optional): The number of top candidates to consider. - num_classes (int, optional): The number of object classes. - alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric. - beta (float, optional): The beta parameter for the localization component of the task-aligned metric. - eps (float, optional): A small value to prevent division by zero. 
- """ - super().__init__() - self.topk = topk - self.num_classes = num_classes - self.alpha = alpha - self.beta = beta - self.eps = eps - - @torch.no_grad() - def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): - """ - Compute the task-aligned assignment. - - Args: - pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). - pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). - anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2). - gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). - gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). - mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1). - - Returns: - target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors). - target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4). - target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes). - fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors). - target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors). 
- - References: - https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py - """ - self.bs = pd_scores.shape[0] - self.n_max_boxes = gt_bboxes.shape[1] - device = gt_bboxes.device - - if self.n_max_boxes == 0: - return ( - torch.full_like(pd_scores[..., 0], self.num_classes), - torch.zeros_like(pd_bboxes), - torch.zeros_like(pd_scores), - torch.zeros_like(pd_scores[..., 0]), - torch.zeros_like(pd_scores[..., 0]), - ) - - try: - return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt) - except torch.cuda.OutOfMemoryError: - # Move tensors to CPU, compute, then move back to original device - LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU") - cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)] - result = self._forward(*cpu_tensors) - return tuple(t.to(device) for t in result) - - def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): - """ - Compute the task-aligned assignment. - - Args: - pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). - pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). - anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2). - gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). - gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). - mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1). - - Returns: - target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors). - target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4). - target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes). - fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors). 
- target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors). - """ - mask_pos, align_metric, overlaps = self.get_pos_mask( - pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt - ) - - target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes) - - # Assigned target - target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask) - - # Normalize - align_metric *= mask_pos - pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj - pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj - norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) - target_scores = target_scores * norm_align_metric - - return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx - - def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt): - """ - Get positive mask for each ground truth box. - - Args: - pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). - pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). - gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). - gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). - anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2). - mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1). - - Returns: - mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w). - align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w). - overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w). 
- """ - mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes) - # Get anchor_align metric, (b, max_num_obj, h*w) - align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt) - # Get topk_metric mask, (b, max_num_obj, h*w) - mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool()) - # Merge all mask to a final mask, (b, max_num_obj, h*w) - mask_pos = mask_topk * mask_in_gts * mask_gt - - return mask_pos, align_metric, overlaps - - def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt): - """ - Compute alignment metric given predicted and ground truth bounding boxes. - - Args: - pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes). - pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4). - gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1). - gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4). - mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w). - - Returns: - align_metric (torch.Tensor): Alignment metric combining classification and localization. - overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes. 
- """ - na = pd_bboxes.shape[-2] - mask_gt = mask_gt.bool() # b, max_num_obj, h*w - overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device) - bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device) - - ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj - ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj - ind[1] = gt_labels.squeeze(-1) # b, max_num_obj - # Get the scores of each grid for each gt cls - bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w - - # (b, max_num_obj, 1, 4), (b, 1, h*w, 4) - pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt] - gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt] - overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes) - - align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) - return align_metric, overlaps - - def iou_calculation(self, gt_bboxes, pd_bboxes): - """ - Calculate IoU for horizontal bounding boxes. - - Args: - gt_bboxes (torch.Tensor): Ground truth boxes. - pd_bboxes (torch.Tensor): Predicted boxes. - - Returns: - (torch.Tensor): IoU values between each pair of boxes. - """ - return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0) - - def select_topk_candidates(self, metrics, topk_mask=None): - """ - Select the top-k candidates based on the given metrics. - - Args: - metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is - the maximum number of objects, and h*w represents the total number of anchor points. - topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where - topk is the number of top candidates to consider. If not provided, the top-k values are automatically - computed based on the given metrics. 
- - Returns: - (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates. - """ - # (b, max_num_obj, topk) - topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True) - if topk_mask is None: - topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs) - # (b, max_num_obj, topk) - topk_idxs.masked_fill_(~topk_mask, 0) - - # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w) - count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device) - ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device) - for k in range(self.topk): - # Expand topk_idxs for each value of k and add 1 at the specified positions - count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones) - # Filter invalid bboxes - count_tensor.masked_fill_(count_tensor > 1, 0) - - return count_tensor.to(metrics.dtype) - - def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask): - """ - Compute target labels, target bounding boxes, and target scores for the positive anchor points. - - Args: - gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the - batch size and max_num_obj is the maximum number of objects. - gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4). - target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive - anchor points, with shape (b, h*w), where h*w is the total - number of anchor points. - fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive - (foreground) anchor points. - - Returns: - target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w). - target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4). - target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes). 
- """ - # Assigned target labels, (b, 1) - batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None] - target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w) - target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w) - - # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4) - target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx] - - # Assigned target scores - target_labels.clamp_(0) - - # 10x faster than F.one_hot() - target_scores = torch.zeros( - (target_labels.shape[0], target_labels.shape[1], self.num_classes), - dtype=torch.int64, - device=target_labels.device, - ) # (b, h*w, 80) - target_scores.scatter_(2, target_labels.unsqueeze(-1), 1) - - fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80) - target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) - - return target_labels, target_bboxes, target_scores - - @staticmethod - def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): - """ - Select positive anchor centers within ground truth bounding boxes. - - Args: - xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2). - gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4). - eps (float, optional): Small value for numerical stability. - - Returns: - (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w). - - Note: - b: batch size, n_boxes: number of ground truth boxes, h: height, w: width. - Bounding box format: [x_min, y_min, x_max, y_max]. 
- """ - n_anchors = xy_centers.shape[0] - bs, n_boxes, _ = gt_bboxes.shape - lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom - bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1) - return bbox_deltas.amin(3).gt_(eps) - - @staticmethod - def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): - """ - Select anchor boxes with highest IoU when assigned to multiple ground truths. - - Args: - mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w). - overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w). - n_max_boxes (int): Maximum number of ground truth boxes. - - Returns: - target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w). - fg_mask (torch.Tensor): Foreground mask, shape (b, h*w). - mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w). - """ - # Convert (b, n_max_boxes, h*w) -> (b, h*w) - fg_mask = mask_pos.sum(-2) - if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes - mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w) - max_overlaps_idx = overlaps.argmax(1) # (b, h*w) - - is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) - is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1) - - mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w) - fg_mask = mask_pos.sum(-2) - # Find each grid serve which gt(index) - target_gt_idx = mask_pos.argmax(-2) # (b, h*w) - return target_gt_idx, fg_mask, mask_pos - - -class RotatedTaskAlignedAssigner(TaskAlignedAssigner): - """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric.""" - - def iou_calculation(self, gt_bboxes, pd_bboxes): - """Calculate IoU for rotated bounding boxes.""" - return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0) - - @staticmethod - def 
select_candidates_in_gts(xy_centers, gt_bboxes): - """ - Select the positive anchor center in gt for rotated bounding boxes. - - Args: - xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2). - gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5). - - Returns: - (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w). - """ - # (b, n_boxes, 5) --> (b, n_boxes, 4, 2) - corners = xywhr2xyxyxyxy(gt_bboxes) - # (b, n_boxes, 1, 2) - a, b, _, d = corners.split(1, dim=-2) - ab = b - a - ad = d - a - - # (b, n_boxes, h*w, 2) - ap = xy_centers - a - norm_ab = (ab * ab).sum(dim=-1) - norm_ad = (ad * ad).sum(dim=-1) - ap_dot_ab = (ap * ab).sum(dim=-1) - ap_dot_ad = (ap * ad).sum(dim=-1) - return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box - - -def make_anchors(feats, strides, grid_cell_offset=0.5): - """Generate anchors from features.""" - anchor_points, stride_tensor = [], [] - assert feats is not None - dtype, device = feats[0].dtype, feats[0].device - for i, stride in enumerate(strides): - h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1])) - sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x - sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y - sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx) - anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2)) - stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) - return torch.cat(anchor_points), torch.cat(stride_tensor) - - -def dist2bbox(distance, anchor_points, xywh=True, dim=-1): - """Transform distance(ltrb) to box(xywh or xyxy).""" - lt, rb = distance.chunk(2, dim) - x1y1 = anchor_points - lt - x2y2 = anchor_points + rb - if xywh: - c_xy = (x1y1 + x2y2) / 2 - wh = x2y2 - x1y1 - return torch.cat((c_xy, wh), dim) # xywh bbox - 
return torch.cat((x1y1, x2y2), dim) # xyxy bbox - - -def bbox2dist(anchor_points, bbox, reg_max): - """Transform bbox(xyxy) to dist(ltrb).""" - x1y1, x2y2 = bbox.chunk(2, -1) - return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb) - - -def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1): - """ - Decode predicted rotated bounding box coordinates from anchor points and distribution. - - Args: - pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4). - pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1). - anchor_points (torch.Tensor): Anchor points with shape (h*w, 2). - dim (int, optional): Dimension along which to split. - - Returns: - (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4). - """ - lt, rb = pred_dist.split(2, dim=dim) - cos, sin = torch.cos(pred_angle), torch.sin(pred_angle) - # (bs, h*w, 1) - xf, yf = ((rb - lt) / 2).split(1, dim=dim) - x, y = xf * cos - yf * sin, xf * sin + yf * cos - xy = torch.cat([x, y], dim=dim) + anchor_points - return torch.cat([xy, lt + rb], dim=dim) diff --git a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/torch_utils.py b/hertz_studio_django_utils/yolo/Train/ultralytics/utils/torch_utils.py deleted file mode 100644 index 73a6633..0000000 --- a/hertz_studio_django_utils/yolo/Train/ultralytics/utils/torch_utils.py +++ /dev/null @@ -1,997 +0,0 @@ -# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license - -import functools -import gc -import math -import os -import random -import time -from contextlib import contextmanager -from copy import deepcopy -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Union - -import numpy as np -import torch -import torch.distributed as dist -import torch.nn as nn -import torch.nn.functional as F - -from ultralytics import __version__ -from ultralytics.utils import ( - DEFAULT_CFG_DICT, - DEFAULT_CFG_KEYS, - LOGGER, - 
NUM_THREADS, - PYTHON_VERSION, - TORCHVISION_VERSION, - WINDOWS, - colorstr, -) -from ultralytics.utils.checks import check_version -from ultralytics.utils.patches import torch_load - -# Version checks (all default to version>=min_version) -TORCH_1_9 = check_version(torch.__version__, "1.9.0") -TORCH_1_13 = check_version(torch.__version__, "1.13.0") -TORCH_2_0 = check_version(torch.__version__, "2.0.0") -TORCH_2_4 = check_version(torch.__version__, "2.4.0") -TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0") -TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0") -TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0") -TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0") -if WINDOWS and check_version(torch.__version__, "==2.4.0"): # reject version 2.4.0 on Windows - LOGGER.warning( - "Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve " - "https://github.com/ultralytics/ultralytics/issues/15049" - ) - - -@contextmanager -def torch_distributed_zero_first(local_rank: int): - """Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first.""" - initialized = dist.is_available() and dist.is_initialized() - use_ids = initialized and dist.get_backend() == "nccl" - - if initialized and local_rank not in {-1, 0}: - dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier() - yield - if initialized and local_rank == 0: - dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier() - - -def smart_inference_mode(): - """Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator.""" - - def decorate(fn): - """Apply appropriate torch decorator for inference mode based on torch version.""" - if TORCH_1_9 and torch.is_inference_mode_enabled(): - return fn # already in inference_mode, act as a pass-through - else: - return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn) - - return decorate - - 
-def autocast(enabled: bool, device: str = "cuda"): - """ - Get the appropriate autocast context manager based on PyTorch version and AMP setting. - - This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both - older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions. - - Args: - enabled (bool): Whether to enable automatic mixed precision. - device (str, optional): The device to use for autocast. - - Returns: - (torch.amp.autocast): The appropriate autocast context manager. - - Notes: - - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`. - - For older versions, it uses `torch.cuda.autocast`. - - Examples: - >>> with autocast(enabled=True): - ... # Your mixed precision operations here - ... pass - """ - if TORCH_1_13: - return torch.amp.autocast(device, enabled=enabled) - else: - return torch.cuda.amp.autocast(enabled) - - -@functools.lru_cache -def get_cpu_info(): - """Return a string with system CPU information, i.e. 'Apple M2'.""" - from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error - - if "cpu_info" not in PERSISTENT_CACHE: - try: - import cpuinfo # pip install py-cpuinfo - - k = "brand_raw", "hardware_raw", "arch_string_raw" # keys sorted by preference - info = cpuinfo.get_cpu_info() # info dict - string = info.get(k[0] if k[0] in info else k[1] if k[1] in info else k[2], "unknown") - PERSISTENT_CACHE["cpu_info"] = string.replace("(R)", "").replace("CPU ", "").replace("@ ", "") - except Exception: - pass - return PERSISTENT_CACHE.get("cpu_info", "unknown") - - -@functools.lru_cache -def get_gpu_info(index): - """Return a string with system GPU information, i.e. 
def select_device(device="", batch=0, newline=False, verbose=True):
    """
    Select the appropriate PyTorch device based on the provided arguments.

    The function takes a string specifying the device or a torch.device object and returns a torch.device object
    representing the selected device. The function also validates the number of available devices and raises an
    exception if the requested device(s) are not available.

    Args:
        device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or
            'cuda', or '0' or '0,1,2,3'. Auto-selects the first available GPU, or CPU if no GPU is available.
        batch (int, optional): Batch size being used in your model.
        newline (bool, optional): If True, adds a newline at the end of the log string.
        verbose (bool, optional): If True, logs the device information.

    Returns:
        (torch.device): Selected device.

    Raises:
        ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
            devices when using multiple GPUs.

    Examples:
        >>> select_device("cuda:0")
        device(type='cuda', index=0)

        >>> select_device("cpu")
        device(type='cpu')

    Notes:
        Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
    """
    # Pass through pre-built torch.device objects and exotic backends unchanged
    if isinstance(device, torch.device) or str(device).startswith(("tpu", "intel")):
        return device

    s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{torch.__version__} "
    device = str(device).lower()
    # Normalize user input: strip decoration so 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
        device = device.replace(remove, "")

    # Auto-select GPUs: each '-1' entry is replaced by an idle GPU index (or dropped if none free)
    if "-1" in device:
        from ultralytics.utils.autodevice import GPUInfo

        parts = device.split(",")
        selected = GPUInfo().select_idle_gpu(count=parts.count("-1"), min_memory_fraction=0.2)
        for i in range(len(parts)):
            if parts[i] == "-1":
                parts[i] = str(selected.pop(0)) if selected else ""
        device = ",".join(p for p in parts if p)

    cpu = device == "cpu"
    mps = device in {"mps", "mps:0"}  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # force torch.cuda.is_available() = False
    elif device:  # non-cpu device requested
        if device == "cuda":
            device = "0"
        if "," in device:
            device = ",".join([x for x in device.split(",") if x])  # remove sequential commas, i.e. "0,,1" -> "0,1"
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        # Must be set BEFORE torch.cuda.is_available() is queried below, since CUDA caches visibility
        os.environ["CUDA_VISIBLE_DEVICES"] = device
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
            LOGGER.info(s)
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        devices = device.split(",") if device else "0"  # i.e. "0,1" -> ["0", "1"]
        n = len(devices)  # device count
        if n > 1:  # multi-GPU: batch must split evenly across devices for DDP
            if batch < 1:
                raise ValueError(
                    "AutoBatch with batch<1 not supported for Multi-GPU training, "
                    f"please specify a valid batch size multiple of GPU count {n}, i.e. batch={n * 8}."
                )
            if batch >= 0 and batch % n != 0:  # check batch_size is divisible by device_count
                raise ValueError(
                    f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                    f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
                )
        space = " " * (len(s) + 1)  # align multi-line device listing under the banner
        for i, d in enumerate(devices):
            s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n"  # bytes to MB
        arg = "cuda:0"
    elif mps and TORCH_2_0 and torch.backends.mps.is_available():
        # Prefer MPS if available
        s += f"MPS ({get_cpu_info()})\n"
        arg = "mps"
    else:  # revert to CPU
        s += f"CPU ({get_cpu_info()})\n"
        arg = "cpu"

    if arg in {"cpu", "mps"}:
        torch.set_num_threads(NUM_THREADS)  # reset OMP_NUM_THREADS for cpu training
    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)


def time_sync():
    """Return PyTorch-accurate time (synchronizes CUDA first so pending GPU kernels are included in timings)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()
- - # Prepare filters - w_conv = conv.weight.view(conv.out_channels, -1) - w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) - fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape)) - - # Prepare spatial bias - b_conv = ( - torch.zeros(conv.weight.shape[0], dtype=conv.weight.dtype, device=conv.weight.device) - if conv.bias is None - else conv.bias - ) - b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) - fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) - - return fusedconv - - -def fuse_deconv_and_bn(deconv, bn): - """Fuse ConvTranspose2d() and BatchNorm2d() layers.""" - fuseddconv = ( - nn.ConvTranspose2d( - deconv.in_channels, - deconv.out_channels, - kernel_size=deconv.kernel_size, - stride=deconv.stride, - padding=deconv.padding, - output_padding=deconv.output_padding, - dilation=deconv.dilation, - groups=deconv.groups, - bias=True, - ) - .requires_grad_(False) - .to(deconv.weight.device) - ) - - # Prepare filters - w_deconv = deconv.weight.view(deconv.out_channels, -1) - w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) - fuseddconv.weight.copy_(torch.mm(w_bn, w_deconv).view(fuseddconv.weight.shape)) - - # Prepare spatial bias - b_conv = torch.zeros(deconv.weight.shape[1], device=deconv.weight.device) if deconv.bias is None else deconv.bias - b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) - fuseddconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) - - return fuseddconv - - -def model_info(model, detailed=False, verbose=True, imgsz=640): - """ - Print and return detailed model information layer by layer. - - Args: - model (nn.Module): Model to analyze. - detailed (bool, optional): Whether to print detailed layer information. - verbose (bool, optional): Whether to print model information. - imgsz (int | list, optional): Input image size. 
def model_info(model, detailed=False, verbose=True, imgsz=640):
    """
    Print and return detailed model information layer by layer.

    Args:
        model (nn.Module): Model to analyze.
        detailed (bool, optional): Whether to print detailed layer information.
        verbose (bool, optional): Whether to print model information.
        imgsz (int | list, optional): Input image size.

    Returns:
        n_l (int): Number of layers.
        n_p (int): Number of parameters.
        n_g (int): Number of gradients.
        flops (float): GFLOPs.
    """
    if not verbose:
        return  # silent mode: return None, matching historical behavior
    n_p = get_num_params(model)  # number of parameters
    n_g = get_num_gradients(model)  # number of gradients
    # Leaf modules only (no children); plain dict preserves insertion order on Python 3.7+
    layers = {n: m for n, m in model.named_modules() if len(m._modules) == 0}
    n_l = len(layers)  # number of layers
    if detailed:
        # FIX: header now includes the 'dtype' column that every per-parameter row already prints at width 15
        LOGGER.info(
            f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}"
            f"{'shape':>20}{'mu':>10}{'sigma':>10}{'dtype':>15}"
        )
        for i, (mn, m) in enumerate(layers.items()):
            mn = mn.replace("module_list.", "")  # strip legacy Darknet prefix
            mt = m.__class__.__name__
            if len(m._parameters):
                for pn, p in m.named_parameters():
                    LOGGER.info(
                        f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
                    )
            else:  # layers with no learnable params
                LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")

    flops = get_flops(model, imgsz)  # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
    fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
    fs = f", {flops:.1f} GFLOPs" if flops else ""
    yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
    model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
    LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
    return n_l, n_p, n_g, flops
def get_num_params(model):
    """Return the total number of parameters in a YOLO model."""
    total = 0
    for p in model.parameters():
        total += p.numel()
    return total


def get_num_gradients(model):
    """Return the total number of trainable (gradient-carrying) parameters in a YOLO model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_flops(model, imgsz=640):
    """
    Calculate FLOPs (floating point operations) for a model in billions.

    Attempts two calculation methods: first with a stride-based tensor for efficiency,
    then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0
    if thop library is unavailable or calculation fails.

    Args:
        model (nn.Module): The model to calculate FLOPs for.
        imgsz (int | list, optional): Input image size.

    Returns:
        (float): The model FLOPs in billions.
    """
    try:
        import thop
    except ImportError:
        thop = None  # conda support without 'ultralytics-thop' installed

    if not thop:
        return 0.0  # if not installed return 0.0 GFLOPs

    try:
        model = de_parallel(model)
        p = next(model.parameters())  # any parameter gives us device and input-channel count
        if not isinstance(imgsz, list):
            imgsz = [imgsz, imgsz]  # expand if int/float
        try:
            # Method 1: profile on a small stride-sized input, then scale FLOPs linearly up to imgsz (cheap)
            stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32  # max stride
            im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
            # thop reports MACs; * 2 converts to FLOPs, / 1e9 to GFLOPs
            flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # stride GFLOPs
            return flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
        except Exception:
            # Method 2: profile at the actual image size (required for RTDETR models)
            im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
            return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2  # imgsz GFLOPs
    except Exception:
        return 0.0  # best-effort: any profiling failure degrades to "unknown FLOPs"
- """ - if not TORCH_2_0: # torch profiler implemented in torch>=2.0 - return 0.0 - model = de_parallel(model) - p = next(model.parameters()) - if not isinstance(imgsz, list): - imgsz = [imgsz, imgsz] # expand if int/float - try: - # Use stride size for input tensor - stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride - im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format - with torch.profiler.profile(with_flops=True) as prof: - model(im) - flops = sum(x.flops for x in prof.key_averages()) / 1e9 - flops = flops * imgsz[0] / stride * imgsz[1] / stride # 640x640 GFLOPs - except Exception: - # Use actual image size for input tensor (i.e. required for RTDETR models) - im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format - with torch.profiler.profile(with_flops=True) as prof: - model(im) - flops = sum(x.flops for x in prof.key_averages()) / 1e9 - return flops - - -def initialize_weights(model): - """Initialize model weights to random values.""" - for m in model.modules(): - t = type(m) - if t is nn.Conv2d: - pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif t is nn.BatchNorm2d: - m.eps = 1e-3 - m.momentum = 0.03 - elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}: - m.inplace = True - - -def scale_img(img, ratio=1.0, same_shape=False, gs=32): - """ - Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple. - - Args: - img (torch.Tensor): Input image tensor. - ratio (float, optional): Scaling ratio. - same_shape (bool, optional): Whether to maintain the same shape. - gs (int, optional): Grid size for padding. - - Returns: - (torch.Tensor): Scaled and padded image tensor. 
- """ - if ratio == 1.0: - return img - h, w = img.shape[2:] - s = (int(h * ratio), int(w * ratio)) # new size - img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize - if not same_shape: # pad/crop img - h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w)) - return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean - - -def copy_attr(a, b, include=(), exclude=()): - """ - Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes. - - Args: - a (Any): Destination object to copy attributes to. - b (Any): Source object to copy attributes from. - include (tuple, optional): Attributes to include. If empty, all attributes are included. - exclude (tuple, optional): Attributes to exclude. - """ - for k, v in b.__dict__.items(): - if (len(include) and k not in include) or k.startswith("_") or k in exclude: - continue - else: - setattr(a, k, v) - - -def get_latest_opset(): - """ - Return the second-most recent ONNX opset version supported by this version of PyTorch, adjusted for maturity. - - Returns: - (int): The ONNX opset version. - """ - if TORCH_1_13: - # If the PyTorch>=1.13, dynamically compute the latest opset minus one using 'symbolic_opset' - return max(int(k[14:]) for k in vars(torch.onnx) if "symbolic_opset" in k) - 1 - # Otherwise for PyTorch<=1.12 return the corresponding predefined opset - version = torch.onnx.producer_version.rsplit(".", 1)[0] # i.e. '2.3' - return {"1.12": 15, "1.11": 14, "1.10": 13, "1.9": 12, "1.8": 12}.get(version, 12) - - -def intersect_dicts(da, db, exclude=()): - """ - Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values. - - Args: - da (dict): First dictionary. - db (dict): Second dictionary. - exclude (tuple, optional): Keys to exclude. - - Returns: - (dict): Dictionary of intersecting keys with matching shapes. 
- """ - return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape} - - -def is_parallel(model): - """ - Return True if model is of type DP or DDP. - - Args: - model (nn.Module): Model to check. - - Returns: - (bool): True if model is DataParallel or DistributedDataParallel. - """ - return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)) - - -def de_parallel(model): - """ - De-parallelize a model: return single-GPU model if model is of type DP or DDP. - - Args: - model (nn.Module): Model to de-parallelize. - - Returns: - (nn.Module): De-parallelized model. - """ - return model.module if is_parallel(model) else model - - -def one_cycle(y1=0.0, y2=1.0, steps=100): - """ - Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf. - - Args: - y1 (float, optional): Initial value. - y2 (float, optional): Final value. - steps (int, optional): Number of steps. - - Returns: - (function): Lambda function for computing the sinusoidal ramp. - """ - return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1 - - -def init_seeds(seed=0, deterministic=False): - """ - Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html. - - Args: - seed (int, optional): Random seed. - deterministic (bool, optional): Whether to set deterministic algorithms. 
- """ - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe - # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287 - if deterministic: - if TORCH_2_0: - torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible - torch.backends.cudnn.deterministic = True - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - os.environ["PYTHONHASHSEED"] = str(seed) - else: - LOGGER.warning("Upgrade to torch>=2.0.0 for deterministic training.") - else: - unset_deterministic() - - -def unset_deterministic(): - """Unset all the configurations applied for deterministic training.""" - torch.use_deterministic_algorithms(False) - torch.backends.cudnn.deterministic = False - os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None) - os.environ.pop("PYTHONHASHSEED", None) - - -class ModelEMA: - """ - Updated Exponential Moving Average (EMA) implementation. - - Keeps a moving average of everything in the model state_dict (parameters and buffers). - For EMA details see References. - - To disable EMA set the `enabled` attribute to `False`. - - Attributes: - ema (nn.Module): Copy of the model in evaluation mode. - updates (int): Number of EMA updates. - decay (function): Decay function that determines the EMA weight. - enabled (bool): Whether EMA is enabled. - - References: - - https://github.com/rwightman/pytorch-image-models - - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage - """ - - def __init__(self, model, decay=0.9999, tau=2000, updates=0): - """ - Initialize EMA for 'model' with given arguments. - - Args: - model (nn.Module): Model to create EMA for. - decay (float, optional): Maximum EMA decay rate. - tau (int, optional): EMA decay time constant. - updates (int, optional): Initial number of updates. 
- """ - self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA - self.updates = updates # number of EMA updates - self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs) - for p in self.ema.parameters(): - p.requires_grad_(False) - self.enabled = True - - def update(self, model): - """ - Update EMA parameters. - - Args: - model (nn.Module): Model to update EMA from. - """ - if self.enabled: - self.updates += 1 - d = self.decay(self.updates) - - msd = de_parallel(model).state_dict() # model state_dict - for k, v in self.ema.state_dict().items(): - if v.dtype.is_floating_point: # true for FP16 and FP32 - v *= d - v += (1 - d) * msd[k].detach() - # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}' - - def update_attr(self, model, include=(), exclude=("process_group", "reducer")): - """ - Update attributes and save stripped model with optimizer removed. - - Args: - model (nn.Module): Model to update attributes from. - include (tuple, optional): Attributes to include. - exclude (tuple, optional): Attributes to exclude. - """ - if self.enabled: - copy_attr(self.ema, model, include, exclude) - - -def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: Dict[str, Any] = None) -> Dict[str, Any]: - """ - Strip optimizer from 'f' to finalize training, optionally save as 's'. - - Args: - f (str | Path): File path to model to strip the optimizer from. - s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be - overwritten. - updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving. - - Returns: - (dict): The combined checkpoint dictionary. 
def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "", updates: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Strip optimizer from 'f' to finalize training, optionally save as 's'.

    Args:
        f (str | Path): File path to model to strip the optimizer from.
        s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be
            overwritten.
        updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.

    Returns:
        (dict): The combined checkpoint dictionary, or an empty dict if 'f' is not a valid checkpoint.

    Examples:
        >>> from pathlib import Path
        >>> from ultralytics.utils.torch_utils import strip_optimizer
        >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
        >>>    strip_optimizer(f)
    """
    try:
        x = torch_load(f, map_location=torch.device("cpu"))
        assert isinstance(x, dict), "checkpoint is not a Python dictionary"
        assert "model" in x, "'model' missing from checkpoint"
    except Exception as e:
        LOGGER.warning(f"Skipping {f}, not a valid Ultralytics model: {e}")
        return {}

    metadata = {
        "date": datetime.now().isoformat(),
        "version": __version__,
        "license": "AGPL-3.0 License (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Update model: prefer the EMA weights as the final model if present
    if x.get("ema"):
        x["model"] = x["ema"]  # replace model with EMA
    if hasattr(x["model"], "args"):
        x["model"].args = dict(x["model"].args)  # convert from IterableSimpleNamespace to dict
    if hasattr(x["model"], "criterion"):
        x["model"].criterion = None  # strip loss criterion
    x["model"].half()  # to FP16 to halve file size
    for p in x["model"].parameters():
        p.requires_grad = False  # inference-only weights

    # Update other keys: null out training-only state (keys are created if absent)
    args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})}  # combine args
    for k in "optimizer", "best_fitness", "ema", "updates":  # keys
        x[k] = None
    x["epoch"] = -1
    x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS}  # strip non-default keys
    # x['model'].args = x['train_args']

    # Save: later dicts win, so explicit `updates` override both metadata and checkpoint keys
    combined = {**metadata, **x, **(updates or {})}
    torch.save(combined, s or f)  # combine dicts (prefer to the right)
    mb = os.path.getsize(s or f) / 1e6  # file size
    LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
    return combined
def convert_optimizer_state_dict_to_fp16(state_dict):
    """
    Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.

    Args:
        state_dict (dict): Optimizer state dictionary.

    Returns:
        (dict): The same dictionary, with FP32 tensors in 'state' converted to FP16 in place.
    """
    for param_state in state_dict["state"].values():
        for key, value in param_state.items():
            # 'step' counters are kept at full precision for optimizer bookkeeping
            if key == "step":
                continue
            if isinstance(value, torch.Tensor) and value.dtype is torch.float32:
                param_state[key] = value.half()

    return state_dict
def profile_ops(input, ops, n=10, device=None, max_num_obj=0):
    """
    Ultralytics speed, memory and FLOPs profiler.

    Args:
        input (torch.Tensor | list): Input tensor(s) to profile.
        ops (nn.Module | list): Model or list of operations to profile.
        n (int, optional): Number of iterations to average.
        device (str | torch.device, optional): Device to profile on.
        max_num_obj (int, optional): Maximum number of objects for simulation.

    Returns:
        (list): Profile results for each operation (None entries mark ops that failed to run).

    Examples:
        >>> from ultralytics.utils.torch_utils import profile_ops
        >>> input = torch.randn(16, 3, 640, 640)
        >>> m1 = lambda x: x * torch.sigmoid(x)
        >>> m2 = nn.SiLU()
        >>> profile_ops(input, [m1, m2], n=100)  # profile over 100 iterations
    """
    try:
        import thop
    except ImportError:
        thop = None  # conda support without 'ultralytics-thop' installed

    results = []
    if not isinstance(device, torch.device):
        device = select_device(device)
    LOGGER.info(
        f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
        f"{'input':>24s}{'output':>24s}"
    )
    gc.collect()  # attempt to free unused memory
    torch.cuda.empty_cache()
    for x in input if isinstance(input, list) else [input]:
        x = x.to(device)
        x.requires_grad = True  # needed so backward timing can be measured
        for m in ops if isinstance(ops, list) else [ops]:
            m = m.to(device) if hasattr(m, "to") else m  # device
            # Match op precision to FP16 inputs when the op supports .half()
            m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
            tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
            try:
                # thop reports MACs; * 2 converts to FLOPs, / 1e9 to GFLOPs
                flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0  # GFLOPs
            except Exception:
                flops = 0

            try:
                mem = 0
                for _ in range(n):
                    with cuda_memory_usage(device) as cuda_info:
                        t[0] = time_sync()
                        y = m(x)
                        t[1] = time_sync()
                        try:
                            (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
                            t[2] = time_sync()
                        except Exception:  # no backward method
                            # print(e)  # for debug
                            t[2] = float("nan")
                    mem += cuda_info["memory"] / 1e9  # (GB)
                    tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                    tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                    if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
                        with cuda_memory_usage(device) as cuda_info:
                            torch.randn(
                                x.shape[0],
                                max_num_obj,
                                int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
                                device=device,
                                dtype=torch.float32,
                            )
                        mem += cuda_info["memory"] / 1e9  # (GB)
                s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
                results.append([p, flops, mem, tf, tb, s_in, s_out])
            except Exception as e:
                LOGGER.info(e)  # best-effort profiling: record the failure and keep going
                results.append(None)
            finally:
                gc.collect()  # attempt to free unused memory
                torch.cuda.empty_cache()
    return results
cuda_info["memory"] / 1e9 # (GB) - s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes - p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters - LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}") - results.append([p, flops, mem, tf, tb, s_in, s_out]) - except Exception as e: - LOGGER.info(e) - results.append(None) - finally: - gc.collect() # attempt to free unused memory - torch.cuda.empty_cache() - return results - - -class EarlyStopping: - """ - Early stopping class that stops training when a specified number of epochs have passed without improvement. - - Attributes: - best_fitness (float): Best fitness value observed. - best_epoch (int): Epoch where best fitness was observed. - patience (int): Number of epochs to wait after fitness stops improving before stopping. - possible_stop (bool): Flag indicating if stopping may occur next epoch. - """ - - def __init__(self, patience=50): - """ - Initialize early stopping object. - - Args: - patience (int, optional): Number of epochs to wait after fitness stops improving before stopping. - """ - self.best_fitness = 0.0 # i.e. mAP - self.best_epoch = 0 - self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop - self.possible_stop = False # possible stop may occur next epoch - - def __call__(self, epoch, fitness): - """ - Check whether to stop training. 
- - Args: - epoch (int): Current epoch of training - fitness (float): Fitness value of current epoch - - Returns: - (bool): True if training should stop, False otherwise - """ - if fitness is None: # check if fitness=None (happens when val=False) - return False - - if fitness > self.best_fitness or self.best_fitness == 0: # allow for early zero-fitness stage of training - self.best_epoch = epoch - self.best_fitness = fitness - delta = epoch - self.best_epoch # epochs without improvement - self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch - stop = delta >= self.patience # stop training if patience exceeded - if stop: - prefix = colorstr("EarlyStopping: ") - LOGGER.info( - f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. " - f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n" - f"To update EarlyStopping(patience={self.patience}) pass a new patience value, " - f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping." - ) - return stop - - -class FXModel(nn.Module): - """ - A custom model class for torch.fx compatibility. - - This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph - manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper - copying. - - Attributes: - model (nn.Module): The original model's layers. - """ - - def __init__(self, model): - """ - Initialize the FXModel. - - Args: - model (nn.Module): The original model to wrap for torch.fx compatibility. - """ - super().__init__() - copy_attr(self, model) - # Explicitly set `model` since `copy_attr` somehow does not copy it. - self.model = model.model - - def forward(self, x): - """ - Forward pass through the model. - - This method performs the forward pass through the model, handling the dependencies between layers and saving - intermediate outputs. 
class FXModel(nn.Module):
    """
    A custom model class for torch.fx compatibility.

    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
    manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
    copying.

    Attributes:
        model (nn.Module): The original model's layers.
    """

    def __init__(self, model):
        """
        Initialize the FXModel.

        Args:
            model (nn.Module): The original model to wrap for torch.fx compatibility.
        """
        super().__init__()
        copy_attr(self, model)
        # Explicitly set `model` since `copy_attr` somehow does not copy it.
        self.model = model.model

    def forward(self, x):
        """
        Forward pass through the model.

        This method performs the forward pass through the model, handling the dependencies between layers and saving
        intermediate outputs.

        Args:
            x (torch.Tensor): The input tensor to the model.

        Returns:
            (torch.Tensor): The output tensor from the model.
        """
        y = []  # saved intermediate outputs, indexed by layer position
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                # Gather this layer's input(s) from earlier layers (int = single source, list = multiple sources)
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
            x = m(x)  # run
            y.append(x)  # save output so later layers can reference it
        return x
class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    This class provides a convenient interface for sending inference requests to a Triton Inference Server
    and processing the responses. Supports both HTTP and gRPC communication protocols.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (List[str]): The data types of the model inputs.
        np_input_formats (List[type]): The numpy data types of the model inputs.
        input_names (List[str]): The names of the model inputs.
        output_names (List[str]): The names of the model outputs.
        metadata: The metadata associated with the model.

    Methods:
        __call__: Call the model with the given inputs and return the outputs.

    Examples:
        Initialize a Triton client with HTTP
        >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")

        Make inference with numpy arrays
        >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
        scheme://netloc/endpoint.

        Args:
            url (str): The URL of the Triton server.
            endpoint (str, optional): The name of the model on the Triton server.
            scheme (str, optional): The communication scheme ('http' or 'grpc').

        Examples:
            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
            >>> model = TritonRemoteModel(url="http://localhost:8000/yolov8")
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/", 1)[0]  # first path segment is the model name
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme (gRPC is the default)
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
        # NOTE(review): eval() on server-supplied metadata executes arbitrary code if the Triton
        # server is untrusted — consider ast.literal_eval for literal-only metadata. Verify trust model.
        self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))

    def __call__(self, *inputs: np.ndarray) -> List[np.ndarray]:
        """
        Call the model with the given inputs and return inference results.

        Args:
            *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type
                for the corresponding model input.

        Returns:
            (List[np.ndarray]): Model outputs with the same dtype as the first input. Each element in the list
                corresponds to one of the model's output tensors.

        Examples:
            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
            >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
        """
        infer_inputs = []
        input_format = inputs[0].dtype  # remember the caller's dtype so outputs can be cast back to it
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])  # cast to the dtype the model expects
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
- **train_args (Any): Additional arguments to pass to the `train()` method. - - Returns: - (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search. - - Examples: - >>> from ultralytics import YOLO - >>> model = YOLO("yolo11n.pt") # Load a YOLO11n model - - Start tuning hyperparameters for YOLO11n training on the COCO8 dataset - >>> result_grid = model.tune(data="coco8.yaml", use_ray=True) - """ - LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune") - if train_args is None: - train_args = {} - - try: - checks.check_requirements("ray[tune]") - - import ray - from ray import tune - from ray.air import RunConfig - from ray.air.integrations.wandb import WandbLoggerCallback - from ray.tune.schedulers import ASHAScheduler - except ImportError: - raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"') - - try: - import wandb - - assert hasattr(wandb, "__version__") - except (ImportError, AssertionError): - wandb = False - - checks.check_version(ray.__version__, ">=2.0.0", "ray") - default_space = { - # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']), - "lr0": tune.uniform(1e-5, 1e-1), - "lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) - "momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 - "weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay - "warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) - "warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum - "box": tune.uniform(0.02, 0.2), # box loss gain - "cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) - "hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) - "hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) - "hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) - "degrees": tune.uniform(0.0, 45.0), # image 
rotation (+/- deg) - "translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction) - "scale": tune.uniform(0.0, 0.9), # image scale (+/- gain) - "shear": tune.uniform(0.0, 10.0), # image shear (+/- deg) - "perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 - "flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability) - "fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability) - "bgr": tune.uniform(0.0, 1.0), # image channel BGR (probability) - "mosaic": tune.uniform(0.0, 1.0), # image mosaic (probability) - "mixup": tune.uniform(0.0, 1.0), # image mixup (probability) - "cutmix": tune.uniform(0.0, 1.0), # image cutmix (probability) - "copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability) - } - - # Put the model in ray store - task = model.task - model_in_store = ray.put(model) - - def _tune(config): - """Train the YOLO model with the specified hyperparameters and return results.""" - model_to_train = ray.get(model_in_store) # get the model from ray store for tuning - model_to_train.reset_callbacks() - config.update(train_args) - results = model_to_train.train(**config) - return results.results_dict - - # Get search space - if not space and not train_args.get("resume"): - space = default_space - LOGGER.warning("Search space not provided, using default search space.") - - # Get dataset - data = train_args.get("data", TASK2DATA[task]) - space["data"] = data - if "data" not in train_args: - LOGGER.warning(f'Data not provided, using default "data={data}".') - - # Define the trainable function with allocated resources - trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0}) - - # Define the ASHA scheduler for hyperparameter search - asha_scheduler = ASHAScheduler( - time_attr="epoch", - metric=TASK2METRIC[task], - mode="max", - max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100, - grace_period=grace_period, - 
reduction_factor=3, - ) - - # Define the callbacks for the hyperparameter search - tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else [] - - # Create the Ray Tune hyperparameter search tuner - tune_dir = get_save_dir( - get_cfg( - DEFAULT_CFG, - {**train_args, **{"exist_ok": train_args.pop("resume", False)}}, # resume w/ same tune_dir - ), - name=train_args.pop("name", "tune"), # runs/{task}/{tune_dir} - ).resolve() # must be absolute dir - tune_dir.mkdir(parents=True, exist_ok=True) - if tune.Tuner.can_restore(tune_dir): - LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...") - tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True) - else: - tuner = tune.Tuner( - trainable_with_resources, - param_space=space, - tune_config=tune.TuneConfig( - scheduler=asha_scheduler, - num_samples=max_samples, - trial_name_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}", - trial_dirname_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}", - ), - run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name), - ) - - # Run the hyperparameter search - tuner.fit() - - # Get the results of the hyperparameter search - results = tuner.get_results() - - # Shut down Ray to clean up workers - ray.shutdown() - - return results diff --git a/hertz_studio_django_utils/yolo/__init__.py b/hertz_studio_django_utils/yolo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/hertz_studio_django_utils/yolo/convert_paths_to_relative.py b/hertz_studio_django_utils/yolo/convert_paths_to_relative.py deleted file mode 100644 index 6da0d05..0000000 --- a/hertz_studio_django_utils/yolo/convert_paths_to_relative.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -将数据库中的绝对路径转换为相对路径的脚本 -执行方式: python convert_paths_to_relative.py -""" -import os -import sys -import django - -# 设置Django环境 -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) 
def convert_to_relative_path(absolute_path):
    """Convert an absolute filesystem path into a path relative to MEDIA_ROOT.

    Returns None for empty input, the input unchanged when it is already
    relative, and the original path (with a warning) when it does not live
    under MEDIA_ROOT (e.g. a different drive on Windows).
    """
    if not absolute_path:
        return None

    # Already relative — nothing to convert.
    if not os.path.isabs(absolute_path):
        return absolute_path

    try:
        return os.path.relpath(absolute_path, settings.MEDIA_ROOT)
    except ValueError:
        # relpath raises when the path cannot be expressed relative to MEDIA_ROOT.
        print(f"警告: 路径 {absolute_path} 不在 MEDIA_ROOT 下")
        return absolute_path
def convert_datasets():
    """Convert absolute paths stored on Dataset rows into MEDIA_ROOT-relative paths.

    Walks every Dataset record, rewrites `root_folder_path` and `data_yaml_path`
    when they are absolute, saves each changed row, and prints a summary.
    """
    print("开始转换 Dataset 表中的路径...")

    updated_count = 0
    for dataset in Dataset.objects.all():
        dirty = False
        # Both path fields get identical treatment.
        for field in ("root_folder_path", "data_yaml_path"):
            old_value = getattr(dataset, field)
            if not (old_value and os.path.isabs(old_value)):
                continue
            setattr(dataset, field, convert_to_relative_path(old_value))
            print(f"  数据集 {dataset.name}: {field}")
            print(f"    原路径: {old_value}")
            print(f"    新路径: {getattr(dataset, field)}")
            dirty = True

        if dirty:
            dataset.save()
            updated_count += 1

    print(f"Dataset 转换完成! 更新了 {updated_count} 个数据集记录\n")
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Video conversion utility.

Converts videos to H.264-encoded MP4 so that browsers can play them.
"""

import os
import subprocess
import json
import logging
from pathlib import Path
from typing import Optional, Dict, Any

logger = logging.getLogger(__name__)


class VideoConverter:
    """Utility class for converting videos to browser-compatible H.264 MP4."""

    # x264 settings per quality level: lower CRF = better quality, slower preset = smaller file.
    _QUALITY_PRESETS = {
        'high': {'crf': '18', 'preset': 'slow'},
        'medium': {'crf': '23', 'preset': 'medium'},
        'low': {'crf': '28', 'preset': 'fast'},
    }

    def __init__(self):
        """Probe once at construction time for FFmpeg availability."""
        self.ffmpeg_available = self._check_ffmpeg()

    def _check_ffmpeg(self) -> bool:
        """Return True if the `ffmpeg` binary can be executed on this host."""
        try:
            probe = subprocess.run(
                ['ffmpeg', '-version'],
                capture_output=True,
                text=True,
                timeout=10,
            )
        except (FileNotFoundError, subprocess.TimeoutExpired):
            logger.warning("FFmpeg未安装或不可用")
            return False
        return probe.returncode == 0

    def get_video_info(self, video_path: str) -> Optional[Dict[str, Any]]:
        """Return codec/resolution/duration/size for *video_path* via ffprobe, or None on any failure."""
        if not self.ffmpeg_available:
            logger.error("FFmpeg不可用,无法获取视频信息")
            return None

        try:
            probe = subprocess.run(
                ['ffprobe', '-v', 'quiet', '-print_format', 'json',
                 '-show_format', '-show_streams', video_path],
                capture_output=True,
                text=True,
                timeout=30,
            )
            if probe.returncode != 0:
                logger.error(f"获取视频信息失败: {probe.stderr}")
                return None

            payload = json.loads(probe.stdout)

            # Pick the first video stream; audio-only files yield None.
            video_stream = next(
                (s for s in payload.get('streams', []) if s.get('codec_type') == 'video'),
                None,
            )
            if video_stream is None:
                logger.error("未找到视频流")
                return None

            return {
                'codec': video_stream.get('codec_name', 'unknown'),
                'width': video_stream.get('width', 0),
                'height': video_stream.get('height', 0),
                'duration': float(payload.get('format', {}).get('duration', 0)),
                'size': os.path.getsize(video_path) if os.path.exists(video_path) else 0,
            }
        except Exception as e:
            logger.error(f"获取视频信息时出错: {str(e)}")
            return None

    def is_h264_compatible(self, video_path: str) -> bool:
        """Return True when the file's video stream is already H.264-encoded."""
        info = self.get_video_info(video_path)
        return bool(info) and info.get('codec', '').lower() == 'h264'

    def convert_to_h264(self, input_path: str, output_path: Optional[str] = None,
                        quality: str = 'medium', overwrite: bool = True) -> Optional[str]:
        """Transcode *input_path* to H.264 MP4.

        Args:
            input_path: Source video file.
            output_path: Destination path; defaults to `<stem>_h264.mp4` next to the source.
            quality: One of 'high', 'medium', 'low' (unknown values fall back to 'medium').
            overwrite: Whether an existing destination file may be replaced.

        Returns:
            str: Path of the H.264 file, or None on failure.
        """
        if not self.ffmpeg_available:
            logger.error("FFmpeg不可用,无法进行视频转换")
            return None

        src = Path(input_path)
        if not src.exists():
            logger.error(f"输入文件不存在: {src}")
            return None

        # Short-circuit: already H.264, no work needed.
        if self.is_h264_compatible(str(src)):
            logger.info(f"视频已经是H.264格式: {src}")
            return str(src)

        dst = Path(output_path) if output_path is not None else src.parent / f"{src.stem}_h264.mp4"

        if dst.exists() and not overwrite:
            logger.info(f"输出文件已存在: {dst}")
            return str(dst)

        opts = self._QUALITY_PRESETS.get(quality, self._QUALITY_PRESETS['medium'])
        cmd = [
            'ffmpeg',
            '-i', str(src),
            '-c:v', 'libx264',            # H.264 encoder
            '-crf', opts['crf'],          # quality setting
            '-preset', opts['preset'],    # encoding speed preset
            '-c:a', 'aac',                # audio encoder
            '-b:a', '128k',               # audio bitrate
            '-movflags', '+faststart',    # optimize for streaming playback
            '-y' if overwrite else '-n',  # overwrite or skip existing output
            str(dst),
        ]

        try:
            logger.info(f"开始转换视频: {src} -> {dst}")
            proc = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=300,  # 5-minute cap on the transcode
            )

            if proc.returncode != 0:
                logger.error(f"视频转换失败: {proc.stderr}")
                return None
            if not dst.exists():
                logger.error("转换完成但输出文件未生成")
                return None

            logger.info(f"视频转换成功: {dst}")
            # Re-probe the output to confirm ffmpeg actually produced H.264.
            if self.is_h264_compatible(str(dst)):
                return str(dst)
            logger.error("转换完成但格式验证失败")
            return None
        except subprocess.TimeoutExpired:
            logger.error("视频转换超时")
            return None
        except Exception as e:
            logger.error(f"视频转换过程中出错: {str(e)}")
            return None

    def ensure_h264_format(self, video_path: str, quality: str = 'medium') -> str:
        """Return an H.264 path for *video_path*, converting if necessary.

        Falls back to the original path when conversion is impossible.
        """
        if self.is_h264_compatible(video_path):
            return video_path
        converted = self.convert_to_h264(video_path, quality=quality)
        return converted or video_path

    def get_conversion_status(self) -> Dict[str, Any]:
        """Report converter availability and the formats it can accept."""
        return {
            'ffmpeg_available': self.ffmpeg_available,
            'supported_formats': ['mp4', 'avi', 'mov', 'mkv', 'wmv', 'flv'] if self.ffmpeg_available else [],
            'output_format': 'H.264 MP4',
        }


# Shared module-level instance used by the rest of the app.
video_converter = VideoConverter()
@echo off
chcp 65001 >nul
rem Bootstrap the Hertz Django backend: check Python, create a venv,
rem then install both public and Hertz-private dependencies.
echo ================================
echo Hertz Django Project Initialization Script
echo ================================

echo Checking Python environment...
python --version
if %errorlevel% neq 0 (
    echo Error: Python not detected, please install Python first!
    pause
    exit /b 1
)

echo Configuring pip global mirror...
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple

echo Creating virtual environment...
python -m venv venv
if %errorlevel% neq 0 (
    echo Error: Failed to create virtual environment!
    pause
    exit /b 1
)

echo Activating virtual environment...
call venv\Scripts\activate
if %errorlevel% neq 0 (
    echo Error: Failed to activate virtual environment!
    pause
    exit /b 1
)

echo Upgrading pip...
python -m pip install --upgrade pip
if %errorlevel% neq 0 (
    echo Error: Failed to upgrade pip!
    pause
    exit /b 1
)

echo Installing Python third-party dependencies...
pip install -r requirements.txt
if %errorlevel% neq 0 (
    echo Error: Failed to install requirements.txt!
    pause
    exit /b 1
)

echo Installing Hertz official dependencies...
pip install -r hertz.txt -i https://hertz:hertz@hzpypi.hzsystems.cn/simple/
if %errorlevel% neq 0 (
    echo Error: Failed to install hertz.txt! Please activate the machine code first.
    pause
    exit /b 1
)

echo ================================
echo Project initialization completed!
echo ================================
echo Please run start_project.bat to start the project
pause
@echo off
rem Launch the Hertz Django backend in a separate console window
rem using the project's virtualenv.
echo ================================
echo Hertz Django
echo ================================

echo activate venv
call venv\Scripts\activate
if %errorlevel% neq 0 (
    echo Please run: init_project.bat
    pause
    exit /b 1
)

rem Spawn the ASGI server in its own window so this console stays free.
start "Hertz Backend" /D "%cd%" cmd /c "python start_server.py"

echo ================================
echo Hertz Django Starting Successful
echo Please wait for the server to start...
echo ================================
pause >nul
#!/usr/bin/env python
"""Django's command-line utility for administrative tasks."""
import os
import sys


def main():
    """Point Django at the project settings module and dispatch the CLI command."""
    os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings')
    try:
        from django.core.management import execute_from_command_line
    except ImportError as exc:
        # Django missing: most likely the virtualenv is not activated.
        message = (
            "Couldn't import Django. Are you sure it's installed and "
            "available on your PYTHONPATH environment variable? Did you "
            "forget to activate a virtual environment?"
        )
        raise ImportError(message) from exc
    execute_from_command_line(sys.argv)


if __name__ == '__main__':
    main()
import call_command -from hertz_studio_django_utils.config.menus_config import menus, add_new_menus -from hertz_studio_django_utils.config.departments_config import departments -from hertz_studio_django_utils.config.roles_config import roles - - - -def register_app_in_settings(settings_path: str, app_name: str) -> bool: - """ - 在settings.py中注册新应用 - - Args: - settings_path: settings.py文件路径 - app_name: 应用名称 - - Returns: - bool: 注册是否成功 - """ - try: - # 读取settings.py文件 - with open(settings_path, 'r', encoding='utf-8') as f: - content = f.read() - - # 检查应用是否已经注册 - if f"'{app_name}'" in content: - print(f"应用 {app_name} 已在settings.py中注册") - return True - - # 使用更精确的正则表达式匹配INSTALLED_APPS - # 匹配从INSTALLED_APPS = [开始到对应的]结束 - pattern = r"INSTALLED_APPS\s*=\s*\[(.*?)\n\]" - match = re.search(pattern, content, re.DOTALL) - - if not match: - print("❌ 未找到INSTALLED_APPS配置") - return False - - # 获取INSTALLED_APPS的内容 - apps_content = match.group(1) - - # 在最后一个应用后添加新应用 - # 找到最后一个应用的位置(以逗号结尾的行) - lines = apps_content.split('\n') - - # 找到最后一个非空行的位置 - last_app_index = -1 - for i in range(len(lines) - 1, -1, -1): - line = lines[i].strip() - if line and not line.startswith('#') and line.endswith(','): - last_app_index = i - break - - if last_app_index >= 0: - # 在最后一个应用后添加新应用 - lines.insert(last_app_index + 1, f" '{app_name}', # 自动注册的应用") - else: - # 如果没有找到合适的位置,在最后添加 - lines.append(f" '{app_name}', # 自动注册的应用") - - # 重新组装内容 - new_apps_content = '\n'.join(lines) - new_content = content.replace( - f"INSTALLED_APPS = [{apps_content}\n]", - f"INSTALLED_APPS = [{new_apps_content}\n]" - ) - - # 写回文件 - with open(settings_path, 'w', encoding='utf-8') as f: - f.write(new_content) - - print(f"✅ 应用 {app_name} 已注册到settings.py") - return True - - except Exception as e: - print(f"❌ 注册应用到settings.py失败: {e}") - return False - - -def register_urls_in_project(urls_path: str, app_name: str) -> bool: - """ - 在项目urls.py中注册新应用的URL路由 - - Args: - urls_path: 项目urls.py文件路径 - app_name: 应用名称 - - Returns: - bool: 
注册是否成功 - """ - try: - # 读取urls.py文件 - with open(urls_path, 'r', encoding='utf-8') as f: - content = f.read() - - # 检查URL是否已经注册 - if f"include('{app_name}.urls')" in content: - print(f"应用 {app_name} 的URL已在项目urls.py中注册") - return True - - # 查找urlpatterns的位置 - pattern = r"(urlpatterns\s*=\s*\[)(.*?)(\n\])" - match = re.search(pattern, content, re.DOTALL) - - if not match: - print("❌ 未找到urlpatterns配置") - return False - - # 获取urlpatterns的内容 - patterns_start = match.group(1) - patterns_content = match.group(2) - patterns_end = match.group(3) - - # 生成URL路由配置 - # 根据应用名称生成合适的URL前缀 - if app_name.startswith('hertz_studio_django_'): - # 提取模块名作为URL前缀 - module_name = app_name.replace('hertz_studio_django_', '') - url_prefix = f"api/{module_name}/" - comment = f"# Hertz {module_name.title()} routes" - else: - url_prefix = f"api/{app_name}/" - comment = f"# {app_name.title()} routes" - - new_route = f"\n {comment}\n path('{url_prefix}', include('{app_name}.urls'))," - - # 在API documentation routes之前添加新路由 - if "# API documentation routes" in patterns_content: - new_patterns_content = patterns_content.replace( - " # API documentation routes", - f" {new_route}\n \n # API documentation routes" - ) - else: - # 如果没有找到API documentation routes,在最后添加 - new_patterns_content = patterns_content.rstrip() + new_route + "\n" - - # 重新组装内容 - new_content = content.replace( - patterns_start + patterns_content + patterns_end, - patterns_start + new_patterns_content + patterns_end - ) - - # 写回文件 - with open(urls_path, 'w', encoding='utf-8') as f: - f.write(new_content) - - print(f"✅ 应用 {app_name} 的URL已注册到项目urls.py") - return True - - except Exception as e: - print(f"❌ 注册URL到项目urls.py失败: {e}") - return False - - -def scan_and_register_new_apps() -> list: - """ - 扫描项目目录,发现并注册新的Django应用 - - Returns: - list: 新注册的应用列表 - """ - print("🔍 扫描项目目录,查找新的Django应用...") - - project_root = Path(__file__).parent - settings_path = project_root / 'hertz_server_django' / 'settings.py' - urls_path = project_root / 
'hertz_server_django' / 'urls.py' - - # 读取当前已注册的应用 - registered_apps = set() - try: - spec = importlib.util.spec_from_file_location("settings", settings_path) - settings_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(settings_module) - registered_apps = set(settings_module.INSTALLED_APPS) - except Exception as e: - print(f"❌ 读取settings.py失败: {e}") - return [] - - # 扫描项目目录,查找Django应用 - new_apps = [] - for item in project_root.iterdir(): - if item.is_dir() and item.name.startswith('hertz_studio_django_'): - app_name = item.name - - # 检查是否是Django应用(包含apps.py文件) - apps_py = item / 'apps.py' - if apps_py.exists() and app_name not in registered_apps: - print(f"🆕 发现新应用: {app_name}") - - # 1. 注册到settings.py - if register_app_in_settings(str(settings_path), app_name): - # 2. 注册到urls.py - if register_urls_in_project(str(urls_path), app_name): - new_apps.append(app_name) - print(f"✅ 应用 {app_name} 注册成功") - else: - print(f"❌ 应用 {app_name} URL注册失败") - else: - print(f"❌ 应用 {app_name} settings注册失败") - - if new_apps: - print(f"🎉 成功注册 {len(new_apps)} 个新应用: {', '.join(new_apps)}") - else: - print("✅ 没有发现新的Django应用") - - return new_apps - - -def execute_migrations_for_new_apps(new_apps: list) -> bool: - """ - 为新注册的应用执行数据库迁移 - - Args: - new_apps: 新应用列表 - - Returns: - bool: 迁移是否成功 - """ - if not new_apps: - return True - - try: - print(f"📋 为新应用执行数据库迁移: {', '.join(new_apps)}") - - for app_name in new_apps: - print(f"📝 为应用 {app_name} 生成迁移文件...") - try: - # 先检查应用是否在Django中正确加载 - from django.apps import apps - try: - app_config = apps.get_app_config(app_name) - print(f"✅ 应用 {app_name} 已正确加载") - except LookupError: - print(f"⚠️ 应用 {app_name} 未在Django中加载,跳过迁移") - continue - - # 生成迁移文件 - call_command('makemigrations', app_name, verbosity=1) - print(f"✅ 应用 {app_name} 迁移文件生成成功") - - # 执行迁移 - call_command('migrate', app_name, verbosity=1) - print(f"✅ 应用 {app_name} 迁移执行成功") - - except Exception as e: - print(f"⚠️ 应用 {app_name} 迁移失败: {e}") - continue - - return True - - 
except Exception as e: - print(f"❌ 执行新应用迁移失败: {e}") - return False -def init_superuser(): - """ - 初始化超级管理员账号 - """ - from hertz_studio_django_auth.models import HertzUser - - print("正在初始化超级管理员账号...") - - # 检查是否已存在超级管理员 - if HertzUser.objects.filter(username='hertz').exists(): - print("超级管理员账号已存在,跳过创建") - return HertzUser.objects.get(username='hertz') - - # 创建超级管理员 - superuser = HertzUser.objects.create_superuser( - username='hertz', - email='admin@hertz.com', - password='hertz', - real_name='超级管理员', - status=1 - ) - - print(f"超级管理员账号创建成功: {superuser.username}") - return superuser - -def init_demo_user(): - from hertz_studio_django_auth.models import HertzUser, HertzUserRole, HertzRole - print("正在初始化普通用户账号...") - if HertzUser.objects.filter(username='demo').exists(): - print("普通用户账号已存在,跳过创建") - user = HertzUser.objects.get(username='demo') - else: - user = HertzUser.objects.create_user( - username='demo', - email='demo@hertz.com', - password='123456', - real_name='普通用户', - status=1 - ) - print(f"普通用户账号创建成功: {user.username}") - try: - role = HertzRole.objects.get(role_id=3) - user_role, created = HertzUserRole.objects.get_or_create(user=user, role=role) - if created: - print(f"为用户 {user.username} 分配角色ID: {role.role_id}") - else: - print(f"用户 {user.username} 已拥有角色ID: {role.role_id}") - except HertzRole.DoesNotExist: - print("角色ID=3不存在,跳过分配") - return user - - -def init_departments(): - """ - 初始化部门数据 - """ - from hertz_studio_django_auth.models import HertzDepartment - - print("正在初始化部门数据...") - - - created_depts = {} - - for dept_data in departments: - parent_code = dept_data.pop('parent_code', None) - parent_id = None - - if parent_code and parent_code in created_depts: - parent_id = created_depts[parent_code] - - dept, created = HertzDepartment.objects.get_or_create( - dept_code=dept_data['dept_code'], - defaults={ - **dept_data, - 'parent_id': parent_id - } - ) - - created_depts[dept.dept_code] = dept - - if created: - print(f"部门创建成功: {dept.dept_name}") - else: - 
print(f"部门已存在: {dept.dept_name}") - - return created_depts - - -def init_menus(): - """ - 初始化菜单数据 - """ - from hertz_studio_django_auth.models import HertzMenu - - print("正在初始化菜单数据...") - - # 菜单数据结构 - created_menus = {} - - # 按层级创建菜单 - for menu_data in menus: - parent_code = menu_data.pop('parent_code', None) - parent_id = None - - if parent_code and parent_code in created_menus: - parent_id = created_menus[parent_code] - - menu, created = HertzMenu.objects.get_or_create( - menu_code=menu_data['menu_code'], - defaults={ - **menu_data, - 'parent_id': parent_id - } - ) - - created_menus[menu.menu_code] = menu - - if created: - print(f"菜单创建成功: {menu.menu_name}") - else: - print(f"菜单已存在: {menu.menu_name}") - - return created_menus - - -def init_roles(): - """ - 初始化角色数据 - """ - from hertz_studio_django_auth.models import HertzRole - - print("正在初始化角色数据...") - - - - created_roles = {} - - for role_data in roles: - role, created = HertzRole.objects.get_or_create( - role_code=role_data['role_code'], - defaults=role_data - ) - - created_roles[role.role_code] = role - - if created: - print(f"角色创建成功: {role.role_name}") - else: - print(f"角色已存在: {role.role_name}") - - return created_roles - - -def assign_role_menus(roles, menus): - """ - 分配角色菜单权限 - """ - from hertz_studio_django_auth.models import HertzRoleMenu - - print("正在分配角色菜单权限...") - - # 超级管理员拥有所有权限 - super_admin_role = roles['super_admin'] - - for menu in menus.values(): - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=super_admin_role, - menu=menu - ) - - if created: - print(f"为超级管理员分配权限: {menu.menu_name}") - - # 系统管理员拥有系统管理权限和工作室权限 - system_admin_role = roles['system_admin'] - - # 系统管理权限(包括日志管理和知识管理) - system_menus = [menu for menu in menus.values() if menu.menu_code.startswith('system')] - for menu in system_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=system_admin_role, - menu=menu - ) - - if created: - print(f"为系统管理员分配系统权限: {menu.menu_name}") - - # 确保系统管理员拥有知识管理权限 - 
wiki_menus = [menu for menu in menus.values() if 'wiki' in menu.menu_code.lower()] - for menu in wiki_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=system_admin_role, - menu=menu - ) - - if created: - print(f"为系统管理员分配知识管理权限: {menu.menu_name}") - - # 确保系统管理员拥有日志管理权限 - log_menus = [menu for menu in menus.values() if 'log' in menu.menu_code.lower()] - for menu in log_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=system_admin_role, - menu=menu - ) - - if created: - print(f"为系统管理员分配日志权限: {menu.menu_name}") - - # 确保超级管理员也拥有所有日志权限(包括动态创建的子菜单) - from hertz_studio_django_auth.models import HertzMenu - all_log_menus = HertzMenu.objects.filter(menu_code__icontains='log', status=1) - for menu in all_log_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=super_admin_role, - menu=menu - ) - - if created: - print(f"为超级管理员分配日志权限: {menu.menu_name}") - - # 确保超级管理员也拥有所有知识管理权限(包括动态创建的子菜单) - all_wiki_menus = HertzMenu.objects.filter(menu_code__icontains='wiki', status=1) - for menu in all_wiki_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=super_admin_role, - menu=menu - ) - - if created: - print(f"为超级管理员分配知识管理权限: {menu.menu_name}") - - # 工作室权限(包括通知公告、AI对话、系统监控) - studio_menus = [menu for menu in menus.values() if menu.menu_code.startswith('studio')] - for menu in studio_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=system_admin_role, - menu=menu - ) - - if created: - print(f"为系统管理员分配工作室权限: {menu.menu_name}") - - # 普通用户拥有工作室基础权限(查询、列表、新增权限) - normal_user_role = roles.get('normal_user') - if normal_user_role: - # 工作室目录权限 - studio_directory = [menu for menu in menus.values() if menu.menu_code == 'studio'] - for menu in studio_directory: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=normal_user_role, - menu=menu - ) - - if created: - print(f"为普通用户分配工作室目录权限: {menu.menu_name}") - - # 工作室各模块的基础权限(查询、列表、新增,排除编辑和删除) - user_studio_menus = [ - 
menu for menu in menus.values() - if menu.menu_code in [ - # 通知公告模块 - 'studio:notice', 'studio:notice:query', 'studio:notice:add', - # AI对话模块 - 包含关键的list权限 - 'studio:ai', 'studio:ai:query', 'studio:ai:add', 'studio:ai:list', - # 系统监控模块 - 'studio:system_monitor', 'studio:system_monitor:query', 'studio:system_monitor:add' - ] - ] - for menu in user_studio_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=normal_user_role, - menu=menu - ) - - if created: - print(f"为普通用户分配工作室权限: {menu.menu_name}") - - # 为普通用户分配知识库权限(查询、列表、新增权限) - user_wiki_menus = [ - menu for menu in menus.values() - if menu.menu_code in [ - # 知识管理主菜单 - 'system:wiki', - # 知识分类权限 - 'system:wiki:category', 'system:wiki:category:list', - 'system:wiki:category:query', 'system:wiki:category:create', - # 知识文章权限 - 'system:wiki:article', 'system:wiki:article:list', - 'system:wiki:article:query', 'system:wiki:article:create' - ] - ] - for menu in user_wiki_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=normal_user_role, - menu=menu - ) - - if created: - print(f"为普通用户分配知识库权限: {menu.menu_name}") - - # 确保普通用户也拥有所有知识管理权限(包括动态创建的子菜单) - all_wiki_menus = HertzMenu.objects.filter(menu_code__icontains='wiki', status=1) - for menu in all_wiki_menus: - # 只给普通用户分配查询和列表权限,不包括删除、修改和编辑权限 - if not any(perm in menu.menu_code for perm in ['remove', 'delete', 'edit', 'update']): - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=normal_user_role, - menu=menu - ) - - if created: - print(f"为普通用户分配知识管理权限: {menu.menu_name}") - - # 为超级管理员和系统管理员分配产品管理权限(包括动态创建的产品菜单) - # 只查询小写的product菜单(正确的权限格式) - all_product_menus = HertzMenu.objects.filter( - menu_code__icontains='product', - status=1 - ) - - # 为超级管理员分配产品管理权限 - for menu in all_product_menus: - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=super_admin_role, - menu=menu - ) - - if created: - print(f"为超级管理员分配产品管理权限: {menu.menu_name}") - - # 为系统管理员分配产品管理权限 - for menu in all_product_menus: - role_menu, 
created = HertzRoleMenu.objects.get_or_create( - role=system_admin_role, - menu=menu - ) - - if created: - print(f"为系统管理员分配产品管理权限: {menu.menu_name}") - - -def assign_user_roles(superuser, roles): - """ - 分配用户角色 - """ - from hertz_studio_django_auth.models import HertzUserRole - - print("正在分配用户角色...") - - # 为超级管理员分配超级管理员角色 - super_admin_role = roles['super_admin'] - - user_role, created = HertzUserRole.objects.get_or_create( - user=superuser, - role=super_admin_role - ) - - if created: - print(f"为用户 {superuser.username} 分配角色: {super_admin_role.role_name}") - else: - print(f"用户 {superuser.username} 已拥有角色: {super_admin_role.role_name}") - - -def sync_generated_menus(): - """ - 同步代码生成器生成的菜单权限 - 动态扫描所有pending_menus_*.py文件 - """ - print("正在检查是否有新生成的菜单需要同步...") - - import glob - import importlib.util - from hertz_studio_django_auth.models import HertzMenu - - # 动态扫描所有pending_menus_*.py文件 - project_root = Path(__file__).parent - pending_files = list(project_root.glob('pending_menus*.py')) - - if not pending_files: - print("没有待同步的菜单文件") - return {} - - all_created_menus = {} - total_synced_count = 0 - - # 首先获取已存在的菜单,用于父级菜单查找 - existing_menus = {menu.menu_code: menu for menu in HertzMenu.objects.all()} - - for pending_file in pending_files: - print(f"处理菜单文件: {pending_file.name}") - - try: - # 动态导入菜单配置文件 - module_name = pending_file.stem # 获取不带扩展名的文件名 - spec = importlib.util.spec_from_file_location(module_name, pending_file) - pending_menus_module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(pending_menus_module) - - if not hasattr(pending_menus_module, 'pending_menus'): - print(f"文件 {pending_file.name} 中没有找到 pending_menus 变量") - continue - - pending_menus = pending_menus_module.pending_menus - - # 添加到菜单配置中 - add_new_menus(pending_menus) - - # 同步到数据库 - created_menus = {} - synced_count = 0 - - for menu_data in pending_menus: - parent_code = menu_data.get('parent_code') - parent_menu = None - - # 先从新创建的菜单中查找父级菜单 - if parent_code and parent_code in 
created_menus: - parent_menu = created_menus[parent_code] - # 再从已存在的菜单中查找父级菜单 - elif parent_code and parent_code in existing_menus: - parent_menu = existing_menus[parent_code] - # 最后从所有已创建的菜单中查找 - elif parent_code and parent_code in all_created_menus: - parent_menu = all_created_menus[parent_code] - - menu, created = HertzMenu.objects.get_or_create( - menu_code=menu_data['menu_code'], - defaults={ - **{k: v for k, v in menu_data.items() if k != 'parent_code'}, - 'parent_id': parent_menu - } - ) - - # 如果菜单已存在但parent_id不同,更新parent_id - if not created and menu.parent_id != parent_menu: - menu.parent_id = parent_menu - menu.save() - - created_menus[menu.menu_code] = menu - all_created_menus[menu.menu_code] = menu - - if created: - print(f"新菜单同步成功: {menu.menu_name}") - synced_count += 1 - else: - print(f"菜单已存在: {menu.menu_name}") - - total_synced_count += synced_count - print(f"文件 {pending_file.name} 处理完成,同步了 {synced_count} 个新菜单") - - except Exception as e: - print(f"处理文件 {pending_file.name} 失败: {e}") - continue - - print(f"菜单同步完成,总共同步了 {total_synced_count} 个新菜单") - return all_created_menus - - -def assign_generated_menu_permissions(generated_menus): - """ - 为生成的菜单分配权限给超级管理员和系统管理员 - """ - if not generated_menus: - return - - from hertz_studio_django_auth.models import HertzRole, HertzRoleMenu - - print("正在为生成的菜单分配权限...") - - try: - # 获取角色 - super_admin_role = HertzRole.objects.get(role_code='super_admin') - system_admin_role = HertzRole.objects.get(role_code='system_admin') - - # 为超级管理员分配所有生成的菜单权限 - for menu in generated_menus.values(): - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=super_admin_role, - menu=menu - ) - - if created: - print(f"为超级管理员分配权限: {menu.menu_name}") - - # 为系统管理员分配生成的菜单权限 - for menu in generated_menus.values(): - role_menu, created = HertzRoleMenu.objects.get_or_create( - role=system_admin_role, - menu=menu - ) - - if created: - print(f"为系统管理员分配权限: {menu.menu_name}") - - except Exception as e: - print(f"分配生成菜单权限失败: {e}") - - 
-def create_menu_generator_command(): - """ - 创建菜单生成器命令行工具 - """ - generator_script = '''#!/usr/bin/env python -""" -菜单生成器命令行工具 -用于快速生成菜单配置和权限同步 -""" - -import os -import sys -import argparse -import django -from pathlib import Path - -# 添加项目路径 -project_root = os.path.dirname(os.path.abspath(__file__)) -sys.path.insert(0, project_root) - -# 设置Django环境 -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings') -django.setup() - -from hertz_studio_django_utils.code_generator.menu_generator import MenuGenerator - - -def generate_crud_menu(args): - """生成CRUD菜单""" - generator = MenuGenerator() - - operations = args.operations.split(',') if args.operations else ['list', 'create', 'update', 'delete'] - - menus = generator.generate_menu_config( - module_name=args.module_name, - model_name=args.model_name, - operations=operations, - parent_code=args.parent_code, - menu_prefix=args.prefix, - sort_order=args.sort_order, - icon=args.icon - ) - - # 保存到待同步文件 - pending_file = os.path.join(project_root, 'pending_menus.py') - with open(pending_file, 'w', encoding='utf-8') as f: - f.write('# 待同步的菜单配置\\n') - f.write('pending_menus = [\\n') - for menu in menus: - f.write(' {\\n') - for key, value in menu.items(): - if isinstance(value, str): - f.write(f" '{key}': '{value}',\\n") - elif value is None: - f.write(f" '{key}': None,\\n") - else: - f.write(f" '{key}': {value},\\n") - f.write(' },\\n') - f.write(']\\n') - - print(f"已生成 {len(menus)} 个菜单配置,保存到 pending_menus.py") - print("请重启服务器以同步菜单到数据库") - - -def menu_generator_main(): - parser = argparse.ArgumentParser(description='菜单生成器') - subparsers = parser.add_subparsers(dest='command', help='可用命令') - - # CRUD菜单生成命令 - crud_parser = subparsers.add_parser('crud', help='生成CRUD菜单') - crud_parser.add_argument('module_name', help='模块名称(中文)') - crud_parser.add_argument('model_name', help='模型名称(英文)') - crud_parser.add_argument('--parent-code', default='system', help='父级菜单代码') - crud_parser.add_argument('--prefix', 
default='system', help='菜单前缀') - crud_parser.add_argument('--operations', help='操作列表(逗号分隔)') - crud_parser.add_argument('--sort-order', type=int, default=1, help='排序') - crud_parser.add_argument('--icon', help='图标') - - args = parser.parse_args() - - if args.command == 'crud': - generate_crud_menu(args) - else: - parser.print_help() - - -if __name__ == "__main__": - menu_generator_main() -''' - - script_path = os.path.join(Path(__file__).parent, 'generate_menu.py') - with open(script_path, 'w', encoding='utf-8') as f: - f.write(generator_script) - - print(f"菜单生成器命令行工具已创建: {script_path}") - - -def init_database(): - """ - 数据库初始化主函数 - """ - print("开始初始化数据库...") - print("=" * 50) - - try: - with transaction.atomic(): - superuser = init_superuser() - - # 2. 初始化部门 - departments = init_departments() - - # 3. 初始化菜单 - menus = init_menus() - - # 4. 同步代码生成器生成的菜单 - generated_menus = sync_generated_menus() - - # 5. 初始化角色 - roles = init_roles() - - # 6. 分配角色菜单权限 - assign_role_menus(roles, menus) - - # 7. 为生成的菜单分配权限 - assign_generated_menu_permissions(generated_menus) - - assign_user_roles(superuser, roles) - demo_user = init_demo_user() - - # 9. 初始化YOLO模块数据 - # init_yolo_data() - - # 10. 创建菜单生成器命令行工具 - create_menu_generator_command() - - # 11. 
删除待同步文件(如果存在) - sync_file_path = os.path.join(Path(__file__).parent, 'pending_menus.py') - if os.path.exists(sync_file_path): - os.remove(sync_file_path) - print("已删除待同步菜单文件") - - print("=" * 50) - print("数据库初始化完成!") - print("") - print("超级管理员账号信息:") - print(f"用户名: hertz") - print(f"密码: hertz") - print(f"邮箱: admin@hertz.com") - print("") - print("") - print("菜单生成器工具:") - print(f"使用命令: python generate_menu.py crud <模块名> <模型名>") - print("") - print("请妥善保管管理员账号信息!") - print("") - print("普通用户账号信息:") - print("用户名: demo") - print("密码: 123456") - - except Exception as e: - print(f"数据库初始化失败: {str(e)}") - sys.exit(1) - - -# 简化的文件监听实现,避免watchdog的兼容性问题 -class SimpleFileWatcher: - """简单的文件监听器""" - - def __init__(self, paths, callback): - self.paths = paths - self.callback = callback - self.file_times = {} - self.running = False - self.thread = None - self.last_check = 0 - self.check_interval = 1 # 检查间隔(秒) - - def _scan_files(self): - """扫描文件变化""" - for path in self.paths: - if not path.exists(): - continue - - for file_path in path.rglob('*'): - if file_path.is_file() and file_path.suffix in ['.py', '.html', '.css', '.js']: - try: - mtime = file_path.stat().st_mtime - if str(file_path) in self.file_times: - if mtime > self.file_times[str(file_path)]: - print(f"\n📝 检测到文件变化: {file_path}") - print("🔄 正在重启服务器...") - self.file_times[str(file_path)] = mtime - self.callback() - return - else: - self.file_times[str(file_path)] = mtime - except (OSError, PermissionError): - continue - - def _watch_loop(self): - """监听循环""" - while self.running: - try: - current_time = time.time() - if current_time - self.last_check >= self.check_interval: - self._scan_files() - self.last_check = current_time - time.sleep(0.1) - except Exception: - continue - - def start(self): - """启动监听""" - if not self.running: - self.running = True - # 初始化文件时间戳 - self._scan_files() - self.thread = threading.Thread(target=self._watch_loop, daemon=True) - self.thread.start() - - def stop(self): - """停止监听""" - 
self.running = False - if self.thread and self.thread.is_alive(): - try: - self.thread.join(timeout=1) - except KeyboardInterrupt: - # 忽略键盘中断异常,直接继续执行 - pass - -class ServerManager: - """服务器管理器""" - - def __init__(self, host: str = '0.0.0.0', port: int = 8000): - self.process = None - self.watcher = None - self.base_dir = Path(__file__).resolve().parent - self.running = True - self.host = host - self.port = int(port) - - def start_server(self): - """启动服务器进程""" - if self.process: - self.stop_server() - - cmd = [ - sys.executable, '-m', 'daphne', - '-b', self.host, - '-p', str(self.port), - 'hertz_server_django.asgi:application' - ] - - try: - self.process = subprocess.Popen( - cmd, - cwd=self.base_dir, - creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if sys.platform == 'win32' else 0 - ) - print("✅ 服务器启动成功") - return True - except Exception as e: - print(f"❌ 服务器启动失败: {e}") - return False - - def stop_server(self): - """停止服务器进程""" - if self.process: - try: - if sys.platform == 'win32': - # Windows系统使用taskkill命令 - subprocess.run(['taskkill', '/F', '/T', '/PID', str(self.process.pid)], - capture_output=True) - else: - self.process.terminate() - self.process.wait(timeout=5) - except Exception: - pass - finally: - self.process = None - - def restart_server(self): - """重启服务器""" - self.stop_server() - time.sleep(1) # 延迟确保端口释放 - if self.running: - self.start_server() - - def start_file_watcher(self): - """启动文件监听器""" - watch_paths = [ - self.base_dir, - ] - - existing_paths = [path for path in watch_paths if path.exists()] - if existing_paths: - self.watcher = SimpleFileWatcher(existing_paths, self.restart_server) - self.watcher.start() - - for path in existing_paths: - print(f"👀 监听目录: {path.name}") - - def stop_file_watcher(self): - """停止文件监听器""" - if self.watcher: - self.watcher.stop() - - def shutdown(self): - """关闭所有服务""" - self.running = False - self.stop_server() - try: - self.stop_file_watcher() - except KeyboardInterrupt: - # 忽略关闭过程中的键盘中断异常 - pass - -def 
check_database_exists(): - """ - 检查数据库是否存在 - """ - from django.conf import settings - - db_config = settings.DATABASES['default'] - - if db_config['ENGINE'] == 'django.db.backends.sqlite3': - db_path = Path(db_config['NAME']) - return db_path.exists() - else: - try: - from django.db import connection - connection.ensure_connection() - return True - except Exception: - return False - -def create_mysql_database_if_missing(): - from django.conf import settings - db = settings.DATABASES['default'] - if db['ENGINE'] != 'django.db.backends.mysql': - return False - name = db['NAME'] - host = db.get('HOST') or 'localhost' - user = db.get('USER') or 'root' - password = db.get('PASSWORD') or '' - port = int(db.get('PORT') or 3306) - try: - import MySQLdb - except Exception: - return False - try: - conn = MySQLdb.connect(host=host, user=user, passwd=password, port=port) - cur = conn.cursor() - cur.execute(f"CREATE DATABASE IF NOT EXISTS `{name}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") - conn.commit() - cur.close() - conn.close() - return True - except Exception: - return False - - -def run_migrations(): - """ - 执行数据库迁移 - """ - print("正在检查并执行数据库迁移...") - - try: - # 执行makemigrations - print("执行makemigrations...") - from django.core.management import execute_from_command_line - execute_from_command_line(['manage.py', 'makemigrations']) - - # 执行migrate - print("执行migrate...") - execute_from_command_line(['manage.py', 'migrate']) - - print("数据库迁移完成") - return True - except Exception as e: - print(f"数据库迁移失败: {str(e)}") - return False - - -def check_initial_data(): - """ - 检查是否存在初始数据 - """ - from hertz_studio_django_auth.models import HertzUser, HertzMenu - - try: - # 检查是否存在超级管理员用户 - has_superuser = HertzUser.objects.filter(username='hertz').exists() - - # 检查是否存在工作室菜单(新增的菜单) - has_studio_menu = HertzMenu.objects.filter(menu_code='studio').exists() - - # 只有当超级管理员和工作室菜单都存在时,才认为初始数据完整 - return has_superuser and has_studio_menu - except Exception: - # 如果表不存在或其他错误,返回False - 
return False - - -def main(): - """ - 主函数 - 自动化数据库检查、迁移、初始化和服务器启动 - """ - print("🚀 启动Hertz Server Django") - print("📋 开始自动化启动流程...") - print("\n" + "=" * 50) - - # 设置Django环境 - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings') - django.setup() - - # 步骤0: 扫描并注册新应用 - print("🔍 步骤0: 扫描并注册新的Django应用...") - new_apps = scan_and_register_new_apps() - - # 如果有新应用注册,需要重新加载Django设置 - if new_apps: - print("🔄 重新加载Django设置...") - # 重新导入settings模块 - import importlib - from django.conf import settings - - # 重新导入settings模块 - settings_module = importlib.import_module('hertz_server_django.settings') - importlib.reload(settings_module) - - # 重新配置Django - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'hertz_server_django.settings') - django.setup() - - # 为新应用执行迁移 - execute_migrations_for_new_apps(new_apps) - - # 步骤1: 检查数据库是否存在 - print("📊 步骤1: 检查数据库状态...") - if not check_database_exists(): - print("❌ 数据库不存在,需要创建") - created = create_mysql_database_if_missing() - need_migration = True - else: - print("✅ 数据库文件存在") - need_migration = False - - # 步骤2: 执行数据库迁移(如果需要) - if need_migration or not check_initial_data(): - print("\n📋 步骤2: 执行数据库迁移...") - if not run_migrations(): - print("❌ 数据库迁移失败,无法继续") - sys.exit(1) - else: - print("\n✅ 步骤2: 数据库迁移已完成") - - # 步骤3: 检查并初始化数据 - print("\n📋 步骤3: 检查初始数据...") - if not check_initial_data(): - print("❌ 缺少初始数据,开始初始化") - init_database() - else: - print("✅ 初始数据已存在") - # 即使初始数据存在,也要同步生成的菜单 - sync_generated_menus() - - print("\n" + "=" * 50) - print("✅ 数据库准备完成!") - - # 步骤4: 启动服务器 - print("\n📋 步骤4: 启动服务器...") - print("🚀 启动Hertz Server Django (支持HTTP + WebSocket + 热重启)") - parser = argparse.ArgumentParser(add_help=False) - parser.add_argument('--port', type=int) - args, _ = parser.parse_known_args() - env_port = os.environ.get('PORT') or os.environ.get('DJANGO_PORT') - try: - env_port_int = int(env_port) if env_port is not None else None - except ValueError: - env_port_int = None - port = args.port or env_port_int or 8000 - print("📡 
使用Daphne ASGI服务器") - print(f"🌐 HTTP服务: http://127.0.0.1:{port}/") - print(f"🔌 WebSocket服务: ws://127.0.0.1:{port}/ws/") - print("🔥 自动热重启: 已启用") - print("\n按 Ctrl+C 停止服务器\n") - - # 检查依赖 - try: - import daphne - except ImportError: - print("❌ 错误: 未安装daphne") - print("请运行: pip install daphne") - return - - try: - import watchdog - except ImportError: - print("❌ 错误: 未安装watchdog") - print("请运行: pip install watchdog") - return - - # 创建服务器管理器 - server_manager = ServerManager(port=port) - - try: - # 启动服务器 - if server_manager.start_server(): - # 启动文件监听器 - server_manager.start_file_watcher() - - # 保持主线程运行 - while server_manager.running: - try: - time.sleep(1) - except KeyboardInterrupt: - print("\n🛑 收到停止信号,正在关闭服务器...") - break - - except Exception as e: - print(f"❌ 启动失败: {e}") - finally: - server_manager.shutdown() - print("👋 服务器已停止") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/static/arial.ttf b/static/arial.ttf deleted file mode 100644 index 1e519bc..0000000 Binary files a/static/arial.ttf and /dev/null differ diff --git a/static/ffmpeg-7.1.1-essentials_build.zip b/static/ffmpeg-7.1.1-essentials_build.zip deleted file mode 100644 index 75d491b..0000000 Binary files a/static/ffmpeg-7.1.1-essentials_build.zip and /dev/null differ diff --git a/后端部署教程(开发人员用).md b/后端部署教程(开发人员用).md deleted file mode 100644 index 602a242..0000000 --- a/后端部署教程(开发人员用).md +++ /dev/null @@ -1,124 +0,0 @@ -# 后端部署教程 - - - -## 一、**环境要求** - -- `Python 3.10+`(建议 3.12.3) -- 操作系统:Windows -- redis:redis版本建议用5.0.10(默认地址 `redis://127.0.0.1:6379`) - - - -## 二、 获取机器码并激活 - -### 1)获取机器码: - -```python -python get_machine_code.py -``` - -运行后会获得一个机器码,例如:HERTZ_STUDIO_XXXXXXXXXXXXXXXX - -将机器码发给相关技术人员进行激活! 
- - - -## 三、初始化配置 - -### 1)双击运行init_backend.bat - -注:此操作会自动创建环境并下载依赖以及完成数据集迁移 - - - -## 四、配置Django - -### 1)将下载的app注册到django项目中 - -在根目录下面的hertz_server_django文件夹下面的setting文件(如下图)中注册app - -如下图我用到了notice、ai、wiki等模块 -![app注册](docs/img/img_1.png) - -### 2)注册app后配置路由 - -在根目录下面的hertz_server_django文件夹下面的urls.py文件(如下图)中配置路由 - -配置路由参考,例如我要配置名为xxx的路由: - -```python -path('api/xxx/', include('hertz_studio_django_xxx.urls')), -``` - -如下图我配置了notice、ai、wiki等路由 -![路由配置](docs/img/img_2.png) - - - - - - -## **五、启动服务** - -- 通过python脚本启动(支持端口参数): - ```python - python start_server.py --port 8000 - ``` - -- 通过bat脚本启动(快捷) - -​ 双击 start_server.bat - - - -## 六、默认账号 - -- 超级管理员: - - 用户名:`hertz` - - 密码:`hertz` -- 普通用户 - - 用户名:`demo` - - 密码:`123456` - - - -## **七、问题排查** - -### 1)`daphne` 或 `watchdog` 未安装: - -运行:`pip install daphne watchdog -i https://pypi.tuna.tsinghua.edu.cn/simple`(`start_server.py:1235-1248` 有依赖检查)。 - -### 2)Redis 未运行: - -安装并启动 Redis,或调整 `REDIS_URL` 指向可用实例。 - -### 3)视频检测结果展示不了 - -需配置ffmpeg环境,参考文章:https://blog.csdn.net/csdn_yudong/article/details/129182648 - -ffmpeg压缩包在根目录下面的static目录下,如下图。 - ![ffmpeg配置](docs/img/img_3.png) - - - - -## **八、项目结构** - -- 核心配置:`hertz_server_django/settings.py`、`hertz_server_django/urls.py` -- 启动脚本:`start_server.py` -- 依赖清单:`requirements.txt` -- 静态资源:`static/`,媒体资源:`media/` - - - -## 九、**快速启动** - -- 安装依赖: - - pip install -r requirements.txt - - pip install -r hertz.txt -i https://hertz:hertz@hzpypi.hzsystems.cn/simple/ - -- 启动服务:`python start_server.py --port 8000` - diff --git a/启动教程(客户用).txt b/启动教程(客户用).txt deleted file mode 100644 index 37638aa..0000000 --- a/启动教程(客户用).txt +++ /dev/null @@ -1,27 +0,0 @@ -请确保python环境和node环境都已经配置完成!!! - -请确保python环境和node环境都已经配置完成!!! - -请确保python环境和node环境都已经配置完成!!! -==========0. 获取机器码(若激活请跳过)========== -双击get_machine_code.bat -将机器码发给售后或相关技术人员激活! - -==========1. 项目初始化(若初始化请直接跳转到启动)========== -后端初始化:双击init_backend.bat -前端初始化:双击init_frontend.bat - - -==========2. 
项目启动========== -后端启动:双击launch_backend.bat -前端启动:双击launch_frontend.bat - - - -===========3. 默认账号========== -超级管理员: - 用户名:hertz - 密码:hertz -普通用户: - 用户名:demo - 密码:123456 \ No newline at end of file