{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 目标检测\n",
    "\n",
    "Object detection: load a Faster R-CNN model (ResNet-50 FPN backbone) pre-trained on COCO and replace its box-predictor head so that it outputs a user-defined number of classes — here one foreground class (person) plus background."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "from torchvision.models.detection.faster_rcnn import FastRCNNPredictor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the COCO-pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load Faster R-CNN with a ResNet-50 FPN backbone, pre-trained on COCO.\n",
    "# NOTE(review): `pretrained=True` is deprecated since torchvision 0.13 in favour of\n",
    "# `weights=\"DEFAULT\"`; it is kept because this environment's torchvision version is\n",
    "# not pinned. Switch once torchvision >= 0.13 is confirmed.\n",
    "model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Replace the box-predictor head\n",
    "\n",
    "Swap the classification/box-regression head for a new `FastRCNNPredictor` sized to `num_classes`. The new head is freshly constructed rather than loaded from the checkpoint, so it must be trained (fine-tuned) before its predictions are meaningful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of output classes for the new head, including background.\n",
    "num_classes = 2  # 1 class (person) + background\n",
    "\n",
    "# Number of input features expected by the existing classifier layer.\n",
    "in_features = model.roi_heads.box_predictor.cls_score.in_features\n",
    "\n",
    "# Replace the pre-trained head with a new one sized for num_classes.\n",
    "model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)\n",
    "\n",
    "# Sanity check: one score per class and 4 box coordinates per class.\n",
    "assert model.roi_heads.box_predictor.cls_score.out_features == num_classes\n",
    "assert model.roi_heads.box_predictor.bbox_pred.out_features == num_classes * 4"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('tvm38': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3555d4060e1bb256f2e385b42190aa51debd92785a45a343e60f30a52ea749ac"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}