Gemini格式API转OpenAI格式

使用场景

我在 VS Code 上用 Cline 时,想使用 Gemini 模型就得开代理。虽然可以在 settings.json 里设置代理,但这样一设置,所有代码相关的流量都会走代理,而且还不稳定。所以我写了一个把 Gemini 格式 API 转成 OpenAI 格式的中转服务,这样就可以在 Cline 上快乐地使用 Gemini 模型了。实现参考了 openai-gemini 项目,但该项目部署在 Workers 上很容易导致 Gemini 账户被封,所以我把它部署在了 VPS 上。

图片[1]|Gemini格式API转OpenAI格式 - 5v|5v

不用 DeepSeek 的原因见下面的对比:因为代码都比较长,所以选用上下文窗口更大的 Gemini 更靠谱。

Gemini-2.0-Flash

  • 输入上下文窗口:支持1,000,000 tokens(约166.7万中文字符,按1 token≈1.67中文字符换算)。
  • 该模型通过优化的长上下文处理能力,可分析长达140万单词的文本或2小时视频内容。证据显示其输入容量在多个版本(如Flash和Flash-Lite)中均保持100万tokens的一致性。

DeepSeek-R1

  • 输入上下文窗口:支持128,000 tokens(约21.3万中文字符)。
  • 尽管早期资料提到64K tokens的上下文长度,但最新对比报告(2024年12月后)明确其输入容量为128K tokens。这可能因模型升级或不同配置导致差异,需以最新数据为准。

关键差异

  • 容量对比:Gemini-2.0-Flash的输入容量是DeepSeek-R1的约7.8倍(1M vs. 128K tokens)。
  • 多模态支持:Gemini-2.0-Flash支持图像、视频等多模态输入,而DeepSeek-R1仅支持文本。
  • 成本与场景:Gemini-2.0-Flash-Lite在同等输入量下成本更低(每百万tokens约0.075美元),适合高频长文本任务;DeepSeek-R1则以开源和复杂推理见长。

 

python代码如下:

from flask import Flask, request, jsonify, Response, stream_with_context
import requests
import os
import json
import base64
import time
import random
import string
from string import Template

# Flask application object; the routes in this file are registered on it.
app = Flask(__name__)

# Shared configuration
# Root URL from which every upstream request URL in this file is built.
BASE_URL = "https://generativelanguage.googleapis.com"
# Path segment placed between BASE_URL and the resource in request URLs.
API_VERSION = "v1beta"
# Value sent in the 'x-goog-api-client' header of upstream requests.
API_CLIENT = "genai-js/0.21.0"
# Chat model used when a completions request does not specify 'model'.
DEFAULT_MODEL = "gemini-1.5-pro-latest"
# Embedding model used when an embeddings request does not specify 'model'.
DEFAULT_EMBEDDINGS_MODEL = "text-embedding-004"


# 辅助函数
def fix_cors(response):
    """Mark *response* as readable from any origin and hand it back.

    The object is modified in place: its ``Access-Control-Allow-Origin``
    header is set to ``*``. The same object is returned so the call can be
    used directly in a ``return`` statement.
    """
    allow_any_origin = {'Access-Control-Allow-Origin': '*'}
    for header_name, header_value in allow_any_origin.items():
        response.headers[header_name] = header_value
    return response

class HttpError(Exception):
    """Exception that carries the HTTP status code to answer with.

    The route handlers in this file read ``status_code`` from a caught
    exception (falling back to 500 when it is absent), so raising this type
    lets lower-level code choose the response status.
    """

    def __init__(self, message, status_code):
        # The message becomes str(exc); the status is kept as an attribute.
        Exception.__init__(self, message)
        self.status_code = status_code

def generate_chatcmpl_id():
    """Return a random completion id: ``chatcmpl-`` plus 29 ASCII alphanumerics."""
    alphabet = string.ascii_letters + string.digits
    suffix_chars = random.choices(alphabet, k=29)
    return ''.join(['chatcmpl-', *suffix_chars])

# 中间件处理
@app.after_request
def add_cors_headers(response):
    """After-request hook: pass every outgoing response through ``fix_cors``."""
    patched = fix_cors(response)
    return patched

# 路由处理
@app.route('/v1/chat/completions', methods=['POST', 'OPTIONS'])
def handle_completions():
    """Route for OpenAI-style chat completions.

    OPTIONS requests are answered by ``handle_options``. For POST, the API
    key is taken from the second whitespace-separated token of the
    ``Authorization`` header (``Bearer <key>``) and the JSON body is passed
    to ``handle_completions_request``.

    Returns:
        The response produced by ``handle_completions_request``; a 400 JSON
        error when the body is not a JSON object; otherwise a JSON error
        whose status is the exception's ``status_code`` attribute (500 when
        the exception has none).
    """
    if request.method == 'OPTIONS':
        return handle_options()

    # A header without a second token ("Bearer", or a bare key) used to raise
    # IndexError outside the try block; treat it as "no key supplied" instead.
    auth_tokens = (request.headers.get('Authorization') or '').split()
    api_key = auth_tokens[1] if len(auth_tokens) > 1 else None

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            # Missing, unparsable or non-object bodies are a client error.
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        return handle_completions_request(data, api_key)
    except Exception as e:
        return jsonify({'error': str(e)}), getattr(e, 'status_code', 500)

@app.route('/v1/embeddings', methods=['POST', 'OPTIONS'])
def handle_embeddings():
    """Route for OpenAI-style embeddings.

    OPTIONS requests are answered by ``handle_options``. For POST, the API
    key is taken from the second whitespace-separated token of the
    ``Authorization`` header (``Bearer <key>``) and the JSON body is passed
    to ``handle_embeddings_request``.

    Returns:
        The response produced by ``handle_embeddings_request``; a 400 JSON
        error when the body is not a JSON object; otherwise a JSON error
        whose status is the exception's ``status_code`` attribute (500 when
        the exception has none).
    """
    if request.method == 'OPTIONS':
        return handle_options()

    # A header without a second token ("Bearer", or a bare key) used to raise
    # IndexError outside the try block; treat it as "no key supplied" instead.
    auth_tokens = (request.headers.get('Authorization') or '').split()
    api_key = auth_tokens[1] if len(auth_tokens) > 1 else None

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            # Missing, unparsable or non-object bodies are a client error.
            return jsonify({'error': 'Request body must be a JSON object'}), 400
        return handle_embeddings_request(data, api_key)
    except Exception as e:
        return jsonify({'error': str(e)}), getattr(e, 'status_code', 500)

@app.route('/v1/models', methods=['GET', 'OPTIONS'])
def handle_models():
    """Route for the OpenAI-style model listing.

    OPTIONS requests are answered by ``handle_options``. For GET, the API key
    is taken from the second whitespace-separated token of the
    ``Authorization`` header (``Bearer <key>``) and passed to
    ``handle_models_request``.

    Returns:
        The response produced by ``handle_models_request``, or a JSON error
        whose status is the exception's ``status_code`` attribute (500 when
        the exception has none).
    """
    if request.method == 'OPTIONS':
        return handle_options()

    # A header without a second token ("Bearer", or a bare key) used to raise
    # IndexError outside the try block; treat it as "no key supplied" instead.
    auth_tokens = (request.headers.get('Authorization') or '').split()
    api_key = auth_tokens[1] if len(auth_tokens) > 1 else None

    try:
        return handle_models_request(api_key)
    except Exception as e:
        return jsonify({'error': str(e)}), getattr(e, 'status_code', 500)

def handle_options():
    """Build the reply to a CORS preflight request.

    The response has no body and allows any origin, method and header.
    """
    cors_headers = {
        header_name: '*'
        for header_name in (
            'Access-Control-Allow-Origin',
            'Access-Control-Allow-Methods',
            'Access-Control-Allow-Headers',
        )
    }
    return Response(headers=cors_headers)

# 模型列表处理
def handle_models_request(api_key):
    """Fetch the upstream model list and return it in OpenAI list format.

    Args:
        api_key: Key sent upstream as ``x-goog-api-key``; the header is
            omitted when the value is falsy.

    Returns:
        On upstream 200, a JSON response ``{"object": "list", "data": [...]}``
        with one entry per upstream model (``models/`` removed from the
        name). Otherwise the upstream body and status are relayed.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    headers = {
        'x-goog-api-client': API_CLIENT,
        'Content-Type': 'application/json'
    }
    if api_key:
        headers['x-goog-api-key'] = api_key

    # The proxy settings that used to be built here were never passed to
    # requests, so they were removed. requests still honours the standard
    # HTTP_PROXY / HTTPS_PROXY environment variables.
    # A timeout keeps a stalled upstream from blocking the worker forever.
    response = requests.get(
        f"{BASE_URL}/{API_VERSION}/models",
        headers=headers,
        timeout=60,
    )

    if response.status_code == 200:
        models = response.json().get('models', [])
        transformed = {
            "object": "list",
            "data": [
                {
                    "id": model.get('name', '').replace("models/", ""),
                    "object": "model",
                    "created": 0,
                    "owned_by": ""
                } for model in models
            ]
        }
        return jsonify(transformed)

    # Relay only the content type. requests has already decoded the body, so
    # forwarding the upstream Content-Encoding / Content-Length /
    # Transfer-Encoding headers would describe a body we are not sending.
    return Response(
        response.content,
        status=response.status_code,
        content_type=response.headers.get('Content-Type', 'application/json'),
    )

# 嵌入处理
def handle_embeddings_request(data, api_key):
    """Translate an OpenAI embeddings request into a ``batchEmbedContents`` call.

    Args:
        data: Parsed request body. Reads ``model`` (defaults to
            ``DEFAULT_EMBEDDINGS_MODEL``), ``input`` (one value or a list)
            and the optional ``dimensions``.
        api_key: Key sent upstream as ``x-goog-api-key``; the header is
            omitted when the value is falsy.

    Returns:
        On upstream 200, a JSON response in OpenAI embedding-list format.
        Otherwise the upstream body and status are relayed.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    # `or` also covers an explicit `"model": null`, which used to crash on
    # `.startswith`.
    model = data.get('model') or DEFAULT_EMBEDDINGS_MODEL
    inputs = data.get('input', [])

    if not isinstance(inputs, list):
        inputs = [inputs]

    model_path = model if model.startswith("models/") else f"models/{model}"

    dimensions = data.get('dimensions')
    embed_requests = []
    for text in inputs:
        embed_request = {
            "model": model_path,
            "content": {"parts": [{"text": text}]},
        }
        # Only send the field when the caller asked for a specific size,
        # instead of sending an explicit null.
        if dimensions is not None:
            embed_request["outputDimensionality"] = dimensions
        embed_requests.append(embed_request)
    payload = {"requests": embed_requests}

    headers = {
        'x-goog-api-client': API_CLIENT,
        'Content-Type': 'application/json'
    }
    if api_key:
        headers['x-goog-api-key'] = api_key

    # The proxy settings that used to be built here were never passed to
    # requests, so they were removed. requests still honours the standard
    # HTTP_PROXY / HTTPS_PROXY environment variables.
    response = requests.post(
        f"{BASE_URL}/{API_VERSION}/{model_path}:batchEmbedContents",
        json=payload,
        headers=headers,
        timeout=60,
    )

    if response.status_code == 200:
        result = response.json()
        transformed = {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.get('values', [])
                } for idx, emb in enumerate(result.get('embeddings', []))
            ],
            "model": model
        }
        return jsonify(transformed)

    # Relay only the content type. requests has already decoded the body, so
    # forwarding the upstream Content-Encoding / Content-Length /
    # Transfer-Encoding headers would describe a body we are not sending.
    return Response(
        response.content,
        status=response.status_code,
        content_type=response.headers.get('Content-Type', 'application/json'),
    )

# 聊天完成处理
def _iter_sse_data(byte_chunks):
    """Yield the ``data`` payload of each server-sent event in a byte stream.

    Chunks may split events, lines or multi-byte characters at any point.
    Bytes are buffered and cut on ``\\n`` (a trailing ``\\r`` is dropped)
    before decoding, so a split UTF-8 sequence is never decoded in halves.
    Several ``data:`` lines in one event are joined with newlines; a blank
    line ends the event. Lines with other field names are ignored.
    """
    def _data_value(line):
        value = line[len("data:"):]
        # A single leading space after the colon is not part of the value.
        return value[1:] if value.startswith(" ") else value

    pending = b""
    data_lines = []
    for chunk in byte_chunks:
        if not chunk:
            continue
        pending += chunk
        # Everything before the last newline is complete; keep the rest.
        *complete_lines, pending = pending.split(b"\n")
        for raw_line in complete_lines:
            line = raw_line.rstrip(b"\r").decode("utf-8", errors="replace")
            if not line:
                if data_lines:
                    yield "\n".join(data_lines)
                    data_lines = []
            elif line.startswith("data:"):
                data_lines.append(_data_value(line))

    # The stream may end without a final blank line; flush what is left.
    tail = pending.rstrip(b"\r").decode("utf-8", errors="replace")
    if tail.startswith("data:"):
        data_lines.append(_data_value(tail))
    if data_lines:
        yield "\n".join(data_lines)


def handle_completions_request(data, api_key):
    """Forward an OpenAI chat-completions request upstream and translate the reply.

    Args:
        data: Parsed request body. Reads ``model`` (defaults to
            ``DEFAULT_MODEL``) and ``stream``; the rest is converted by
            ``transform_request``.
        api_key: Key sent upstream as ``x-goog-api-key``; the header is
            omitted when the value is falsy.

    Returns:
        * upstream status other than 200: the upstream body and status,
          relayed for streaming and non-streaming requests alike;
        * ``stream`` falsy: a JSON response built by
          ``process_completion_data``;
        * ``stream`` truthy: a ``text/event-stream`` response with one chunk
          per upstream event, followed by ``data: [DONE]``.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    model = data.get('model') or DEFAULT_MODEL
    stream = bool(data.get('stream', False))

    model_path = model if model.startswith("models/") else f"models/{model}"
    task = "streamGenerateContent" if stream else "generateContent"
    url = f"{BASE_URL}/{API_VERSION}/{model_path}:{task}"

    if stream:
        url += "?alt=sse"

    headers = {
        'x-goog-api-client': API_CLIENT,
        'Content-Type': 'application/json'
    }
    if api_key:
        headers['x-goog-api-key'] = api_key

    # The proxy settings that used to be built here were never passed to
    # requests, so they were removed. requests still honours the standard
    # HTTP_PROXY / HTTPS_PROXY environment variables.
    # (connect, read) timeout: fail fast when the host is unreachable, but
    # allow long pauses between bytes while the model is generating.
    response = requests.post(
        url,
        json=transform_request(data),
        headers=headers,
        stream=stream,
        timeout=(10, 600),
    )

    if response.status_code != 200:
        # Checked before streaming starts, so an upstream error is no longer
        # turned into an empty 200 event stream. Only the content type is
        # relayed: requests has already decoded the body, so the upstream
        # Content-Encoding / Content-Length headers would no longer match it.
        return Response(
            response.content,
            status=response.status_code,
            content_type=response.headers.get('Content-Type', 'application/json'),
        )

    if not stream:
        return jsonify(process_completion_data(response.json(), model))

    def generate():
        try:
            chunks = response.iter_content(chunk_size=None)
            for payload in _iter_sse_data(chunks):
                try:
                    event = json.loads(payload)
                except json.JSONDecodeError:
                    # Skip events whose payload is not valid JSON.
                    continue
                yield process_stream_data(event, model)
            yield "data: [DONE]\n\n"
        finally:
            # Release the upstream connection even if the client disconnects.
            response.close()

    return Response(stream_with_context(generate()), mimetype='text/event-stream')

# 请求转换逻辑
def transform_request(data):
    """Convert an OpenAI chat-completions body into a Gemini request body.

    Args:
        data: Parsed OpenAI request. Reads ``messages``, ``stop``, ``n``,
            ``max_tokens`` (falling back to ``max_completion_tokens``),
            ``temperature``, ``top_p`` and ``top_k``.

    Returns:
        A dict with ``contents``, ``generationConfig`` and ``safetySettings``
        (every listed category set to ``BLOCK_NONE``), plus
        ``system_instruction`` when the request has a system message.
        Options the caller did not set are left out rather than sent as null.
    """
    system_instruction = None
    contents = []

    for msg in data.get('messages') or []:
        if msg['role'] == 'system':
            converted = transform_message(msg)
            if system_instruction is None:
                system_instruction = converted
            else:
                # Merge further system messages instead of keeping only the
                # last one.
                system_instruction['parts'].extend(converted['parts'])
        else:
            role = 'model' if msg['role'] == 'assistant' else 'user'
            contents.append(transform_message(msg, role))

    # OpenAI accepts one stop string or a list; stopSequences must be a list.
    stop = data.get('stop')
    if isinstance(stop, str):
        stop = [stop]

    max_tokens = data.get('max_tokens')
    if max_tokens is None:
        max_tokens = data.get('max_completion_tokens')

    requested_config = {
        'stopSequences': stop,
        'candidateCount': data.get('n'),
        'maxOutputTokens': max_tokens,
        'temperature': data.get('temperature'),
        'topP': data.get('top_p'),
        'topK': data.get('top_k')
    }
    generation_config = {
        key: value for key, value in requested_config.items() if value is not None
    }

    payload = {
        'contents': contents,
        'generationConfig': generation_config,
        'safetySettings': [
            {"category": cat, "threshold": "BLOCK_NONE"}
            for cat in [
                "HARM_CATEGORY_HATE_SPEECH",
                "HARM_CATEGORY_SEXUALLY_EXPLICIT",
                "HARM_CATEGORY_DANGEROUS_CONTENT",
                "HARM_CATEGORY_HARASSMENT",
                "HARM_CATEGORY_CIVIC_INTEGRITY"
            ]
        ]
    }
    if system_instruction is not None:
        payload['system_instruction'] = system_instruction
    return payload

def transform_message(msg, role=None):
    """Convert one OpenAI chat message into a Gemini content object.

    Args:
        msg: Message dict. ``content`` may be a string, ``None``/absent
            (treated as an empty string), or a list of typed parts.
        role: Role to put on the result; when falsy, ``msg['role']`` is used.

    Returns:
        ``{'role': ..., 'parts': [...]}``. ``text`` parts become
        ``{'text': ...}``, ``image_url`` parts are converted by
        ``parse_image``, and parts of any other type are skipped.
    """
    content = msg.get('content')
    if content is None:
        # An explicit `"content": null` used to fail when iterated below.
        content = ''

    parts = []
    if isinstance(content, str):
        parts.append({'text': content})
    else:
        for item in content:
            item_type = item.get('type')
            if item_type == 'text':
                parts.append({'text': item['text']})
            elif item_type == 'image_url':
                parts.append(parse_image(item['image_url']['url']))

    return {
        'role': role or msg['role'],
        'parts': parts
    }

def parse_image(url):
    """Turn an image reference into a Gemini ``inlineData`` part.

    Args:
        url: Either an ``http(s)://`` URL, which is downloaded and
            base64-encoded, or a ``data:`` URL, whose payload is passed
            through unchanged (it is assumed to be base64 already).

    Returns:
        ``{'inlineData': {'mimeType': ..., 'data': ...}}``.

    Raises:
        ValueError: if *url* is neither an http(s) URL nor a ``data:`` URL
            containing a comma.
        requests.RequestException: if the download fails, times out, or
            returns an error status.
    """
    if url.startswith(('http://', 'https://')):
        # NOTE(security): this fetches a caller-supplied URL from the server
        # (SSRF risk). Restrict the allowed hosts if this service is exposed
        # to untrusted clients.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # Drop parameters such as "; charset=..." and tolerate servers that
        # send no Content-Type at all.
        content_type = response.headers.get('Content-Type', 'application/octet-stream')
        return {
            'inlineData': {
                'mimeType': content_type.split(';', 1)[0].strip(),
                'data': base64.b64encode(response.content).decode()
            }
        }

    header, separator, payload = url.partition(',')
    if not separator or not header.startswith('data:'):
        raise ValueError('Unsupported image URL: expected an http(s) or data: URL')

    # "data:image/png;base64" -> "image/png". An omitted media type defaults
    # to text/plain for data URLs.
    mime_type = header[len('data:'):].split(';', 1)[0] or 'text/plain'
    return {
        'inlineData': {
            'mimeType': mime_type,
            'data': payload
        }
    }

# 响应处理逻辑
def process_completion_data(data, model):
    """Build an OpenAI ``chat.completion`` object from a Gemini response.

    Args:
        data: Parsed upstream response; ``candidates`` and ``usageMetadata``
            are read, and both may be absent.
        model: Model name copied into the result.

    Returns:
        A dict with a new completion id, the current Unix time, one choice
        per candidate (via ``transform_candidate``) and token usage counts
        that default to 0.
    """
    usage_meta = data.get('usageMetadata', {})

    completion_id = generate_chatcmpl_id()
    created_at = int(time.time())

    choices = []
    for candidate in data.get('candidates', []):
        choices.append(transform_candidate(candidate))

    token_usage = {
        'prompt_tokens': usage_meta.get('promptTokenCount', 0),
        'completion_tokens': usage_meta.get('candidatesTokenCount', 0),
        'total_tokens': usage_meta.get('totalTokenCount', 0)
    }

    return {
        'id': completion_id,
        'object': 'chat.completion',
        'created': created_at,
        'model': model,
        'choices': choices,
        'usage': token_usage
    }

def transform_candidate(candidate):
    """Convert one Gemini candidate into an OpenAI ``choices`` entry.

    Args:
        candidate: Candidate dict; ``index``, ``content.parts`` and
            ``finishReason`` are read, and all may be absent.

    Returns:
        ``{'index', 'message': {'role': 'assistant', 'content'},
        'finish_reason'}``. The content is the concatenated text of all
        parts; parts without ``text`` contribute nothing. Known Gemini finish
        reasons are mapped to OpenAI values, unknown ones pass through
        unchanged, and a missing reason becomes ``'stop'``.
    """
    finish_reason_map = {
        'STOP': 'stop',
        'MAX_TOKENS': 'length',
        'SAFETY': 'content_filter',
        'RECITATION': 'content_filter',
    }

    parts = (candidate.get('content') or {}).get('parts') or []
    # .get: a part without text used to raise KeyError here.
    text = ''.join(part.get('text', '') for part in parts)

    reason = candidate.get('finishReason')
    finish_reason = finish_reason_map.get(reason, reason) if reason else 'stop'

    return {
        'index': candidate.get('index', 0),
        'message': {
            'role': 'assistant',
            'content': text
        },
        'finish_reason': finish_reason
    }

def process_stream_data(data, model, completion_id=None):
    """Format one Gemini stream event as an OpenAI ``chat.completion.chunk`` line.

    Args:
        data: Parsed upstream event. Only the first candidate is used; a
            missing or empty ``candidates`` list yields an empty delta.
        model: Model name copied into the chunk.
        completion_id: Optional id to put on the chunk, so a caller can keep
            one id for a whole stream. When omitted, a new id is generated
            for this chunk.

    Returns:
        The string ``"data: <json>\\n\\n"``. Known Gemini finish reasons are
        mapped to OpenAI values, unknown ones pass through unchanged, and a
        missing reason stays ``None``.
    """
    finish_reason_map = {
        'STOP': 'stop',
        'MAX_TOKENS': 'length',
        'SAFETY': 'content_filter',
        'RECITATION': 'content_filter',
    }

    # `or [{}]` also covers an empty list, which used to raise IndexError.
    candidate = (data.get('candidates') or [{}])[0]
    parts = (candidate.get('content') or {}).get('parts') or []

    delta = {
        'content': ''.join(part.get('text', '') for part in parts)
    }

    reason = candidate.get('finishReason')
    chunk = {
        'id': completion_id or generate_chatcmpl_id(),
        'object': 'chat.completion.chunk',
        'created': int(time.time()),
        'model': model,
        'choices': [{
            'index': candidate.get('index', 0),
            'delta': delta,
            'finish_reason': finish_reason_map.get(reason, reason)
        }]
    }
    return f"data: {json.dumps(chunk)}\n\n"

if __name__ == '__main__':
    # Script entry point: listen on every interface, port 5102, with the
    # Flask debugger off.
    run_options = {'host': '0.0.0.0', 'port': 5102, 'debug': False}
    app.run(**run_options)

 

© 版权声明
THE END
喜欢就支持一下吧
点赞7 分享