Skip to content

上传大文件

技术栈：FastAPI + WebSocket + Vue 3

后端代码（FastAPI，Python）：
import json
import os
import time
import uuid
from pathlib import Path
from typing import List, Dict, Any

import aiofiles
import orjson
from fastapi import (
    APIRouter,
    WebSocket,
    WebSocketDisconnect,
    Depends,
    status,
)
from fastapi.websockets import WebSocketState
from sqlmodel import Session

from db.db import get_db
from db.services.ppt import ppt_service

# Router for the WebSocket upload endpoints: every path registered on it is
# mounted under the "/ws" prefix and tagged "websocket"; a 404 response
# description is declared for all of its routes.
router = APIRouter(
    responses={404: {"description": "Not found"}},
    tags=["websocket"],
    prefix="/ws",
)

# Fallback MIME type per accepted extension, used when the client does not
# report one. The keys double as the whitelist of accepted file types.
_PPT_CONTENT_TYPES: Dict[str, str] = {
    "ppt": "application/vnd.ms-powerpoint",
    "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
}


async def _close_if_open(websocket: WebSocket) -> None:
    """Close *websocket* with the default code unless either side already closed it.

    Starlette raises ``RuntimeError`` when ``close()`` is called after a close
    frame has already been sent, which the early-rejection paths of the
    upload endpoint would otherwise trigger from its ``finally`` clause.
    """
    if (
        websocket.application_state == WebSocketState.CONNECTED
        and websocket.client_state == WebSocketState.CONNECTED
    ):
        await websocket.close()


async def _receive_chunks(
    websocket: WebSocket,
    file_path: Path,
    total_num: int,
    user_id: str,
    start_time: float,
) -> Dict[str, Any]:
    """Write ``total_num`` binary frames to *file_path*, acking each one.

    Every non-empty binary frame is appended to the file and answered with a
    progress JSON message. Text frames are parsed as JSON metadata; the last
    non-empty JSON object received is returned (``{}`` if none was sent).

    Raises:
        WebSocketDisconnect: the client disconnected before all chunks arrived.
        json.JSONDecodeError: a text frame was not valid JSON.
    """
    metadata: Dict[str, Any] = {}
    num = 0
    async with aiofiles.open(file_path, "wb") as f:
        while num < total_num:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                # receive() reports a disconnect as a message instead of
                # raising, and calling it again afterwards raises RuntimeError.
                raise WebSocketDisconnect(message.get("code", 1000))

            # ASGI allows both "text" and "bytes" keys with one of them None,
            # so check the value rather than key membership.
            text = message.get("text")
            if text is not None:
                data = json.loads(text)
                print(f"---000---data: {data}")
                if isinstance(data, dict) and data:
                    metadata = data
                continue

            chunk = message.get("bytes")
            if not chunk:
                continue
            await f.write(chunk)
            num += 1
            await websocket.send_json(
                {
                    "message": f"{num}/{total_num}上传完成",
                    "user_id": user_id,
                    "file_path": str(file_path),
                    "time": time.monotonic() - start_time,
                    "total_num": total_num,
                    "num": num,
                    # Integer math avoids float truncation (int(29/100*100) == 28).
                    "progress": num * 100 // total_num,
                }
            )
    return metadata


@router.websocket("/upload/{file_type}/{total_num}")
async def websocket_upload_endpoint(
    websocket: WebSocket,
    file_type: str,
    total_num: int,
    session: Session = Depends(get_db),
):
    """Receive a PPT/PPTX upload split into ``total_num`` binary chunks.

    Protocol:
      1. The client connects with a ``?token=`` query parameter; the server
         answers with a ``"连接成功"`` JSON message containing ``file_path``.
      2. The client sends a JSON text frame with ``file_name``/``file_type``
         metadata, then one binary frame per chunk; each chunk is acked with
         a progress JSON message.
      3. After the last chunk a record is created via ``ppt_service.create``
         and a ``{"code": 200, "message": "上传成功", "data": ...}`` text
         frame is sent.

    The socket is closed with 1008 (policy violation) and the handler returns
    when the token is missing/invalid, *file_type* is not ``ppt``/``pptx``, or
    *total_num* is below 1. If no database record was created (client
    disconnect or an error), the partially written file is deleted.
    """
    print(f"file_type: {file_type}")
    start_time = time.monotonic()
    await websocket.accept()
    file_path = None
    record_created = False
    try:
        token = websocket.query_params.get("token")
        if not token:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Token 缺失")
            return

        # TODO: verify the token and resolve the real user id; this value is
        # a hard-coded placeholder.
        user_id = "user_id_12333"
        print(f"user_id: {user_id}")
        if user_id is None:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Token 无效")
            return

        if file_type not in _PPT_CONTENT_TYPES:
            await websocket.close(
                code=status.WS_1008_POLICY_VIOLATION, reason="仅支持PPT和PPTX文件格式"
            )
            return

        if total_num < 1:
            # With zero chunks the receive loop would never have work to finish.
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="分片数量无效")
            return

        # Stored as <user_id>/<random hex>.<ext>, relative to the working dir.
        file_path = Path(user_id) / f"{uuid.uuid4().hex}.{file_type}"
        file_path.parent.mkdir(parents=True, exist_ok=True)

        await websocket.send_json(
            {"message": "连接成功", "user_id": user_id, "file_path": str(file_path)}
        )

        metadata = await _receive_chunks(websocket, file_path, total_num, user_id, start_time)
        print(f"file_path: {file_path}")

        ppt_data = {
            "user_id": user_id,
            # Fall back to the stored name if the client never sent metadata.
            "filename": str(metadata.get("file_name") or file_path.name),
            "file_path": str(file_path),
            "file_size": os.path.getsize(file_path),
            # Missing/empty client MIME type -> extension-specific default.
            "content_type": str(metadata.get("file_type") or _PPT_CONTENT_TYPES[file_type]),
        }
        print(f"ppt_data: {ppt_data}")
        # NOTE(review): ppt_service.create is called without await; if it does
        # blocking DB I/O it stalls the event loop — consider a threadpool.
        ppt = ppt_service.create(session, ppt_data)
        record_created = True
        await websocket.send_text(
            orjson.dumps({"code": 200, "message": "上传成功", "data": ppt.dict()}).decode("utf-8")
        )
    except WebSocketDisconnect:
        print("Client disconnected")
    finally:
        if file_path is not None and not record_created:
            # Nothing references an incomplete/unrecorded upload; remove it.
            file_path.unlink(missing_ok=True)
        await _close_if_open(websocket)

前端代码（Vue 3，JavaScript）：
import { baseUrl } from "@/utils/http";
import SparkMD5 from 'spark-md5'

/**
 * Build the WebSocket URL for an API path.
 *
 * Relative paths are prefixed with `baseUrl`, switched from http(s) to
 * ws(s), and given the auth token as a `token` query parameter. URLs that
 * already start with http:// or https:// are returned unchanged.
 *
 * NOTE(review): `getToken` is not imported in this module — confirm it is a
 * global, or import it from wherever it is defined.
 *
 * @param {string} url - API path such as `/ws/upload/pptx/3`, or a full URL.
 * @returns {string} The URL to pass to `new WebSocket(...)`.
 */
function getWsUrl(url) {
  if (url.startsWith('http://') || url.startsWith('https://')) {
    return url;
  }
  // Rewrite only the leading scheme; an unanchored replace() would also hit
  // an "http://" appearing later in the string.
  const wsUrl = `${baseUrl}${url}`.replace(/^http(s?):\/\//, 'ws$1://');
  // Encode the token so "+", "&", "=" etc. survive the query string, and
  // append with "&" when the path already carries a query.
  const separator = wsUrl.includes('?') ? '&' : '?';
  return `${wsUrl}${separator}token=${encodeURIComponent(getToken())}`;
}

/**
 * Upload `file` over a WebSocket in 1 MB chunks, one chunk per server ack.
 *
 * Flow: open the socket, send a JSON metadata frame (`file_name`,
 * `file_size`, `file_type`), then send chunk 0. Every server message that
 * carries `num` is treated as a progress ack and triggers the next chunk
 * until all chunks are sent. A message with `code === 200` means success.
 *
 * @param {File} file - File to upload.
 * @param {(data: object) => void} successCallback - Called with the `code: 200` message.
 * @param {(data: object) => void} progressCallback - Called with each progress message.
 * @param {(event: CloseEvent) => void} [errorCallback] - Called once if the
 *   socket closes before a success message was received.
 * @returns {{sliceFile: Function, startUpload: Function, startWithWs: Function}}
 */
export const useUploadHook = (
  file,
  successCallback = () => { },
  progressCallback = () => { },
  errorCallback = () => { },
) => {

  let ws = null;

  // File chunks
  const chunkSize = 1024 * 1024 * 1; // 1MB
  const totalChunks = Math.ceil(file.size / chunkSize);
  let chunks = []; // Blob slices of `file`
  let chunkIndex = 0; // index of the chunk currently being sent
  let succeeded = false; // set once the server reports code 200

  // Split the file into Blob slices; resets first so repeated calls do not
  // accumulate duplicate chunks.
  function sliceFile() {
    chunks = [];
    for (let i = 0; i < totalChunks; i++) {
      chunks.push(file.slice(i * chunkSize, (i + 1) * chunkSize));
    }
  }

  // Send chunk `i`; a no-op once every chunk has been sent or the socket is
  // no longer open (the final progress ack still bumps chunkIndex).
  async function sendData(i) {
    if (i >= chunks.length || !ws || ws.readyState !== WebSocket.OPEN) {
      return;
    }
    const chunkData = await chunks[i].arrayBuffer();
    ws.send(chunkData);
  }

  async function startWithWs() {
    if (chunks.length !== totalChunks) {
      sliceFile();
    }
    succeeded = false;
    // Extension after the last "." of the file name
    const fileType = file.name.split('.').pop();
    const url = getWsUrl(`/ws/upload/${fileType}/${totalChunks}`);
    ws = new WebSocket(url);
    ws.onopen = async () => {
      console.log('连接成功');
      ws.send(JSON.stringify({
        file_name: file.name,
        file_size: file.size,
        file_type: file.type,
      }));
      chunkIndex = 0;
      await sendData(chunkIndex);
    };
    ws.onmessage = async (event) => {
      console.log('收到消息:', event.data);
      const data = JSON.parse(event.data);
      if (data.code === 200) {
        succeeded = true;
        successCallback(data);
      } else if (data?.num) {
        progressCallback(data);
        chunkIndex++;
        await sendData(chunkIndex);
      }
    };
    ws.onerror = (error) => {
      console.error('WebSocket 错误:', error);
    };
    ws.onclose = (event) => {
      console.log('连接关闭:', event.code, event.reason);
      if (!succeeded) {
        errorCallback(event);
      }
    };
  }

  // Convenience entry point: slice the file, then upload it. (Previously
  // returned from the hook without being defined.)
  async function startUpload() {
    sliceFile();
    await startWithWs();
  }

  return {
    sliceFile,
    startUpload,
    startWithWs,
  };
};