カスタムモデルの組み込み
カスタムモデルとは、ユーザー自身でデプロイまたは設定する必要があるLLMのことです。この記事では、Xinferenceモデルを例に、モデルプラグイン内でカスタムモデルを組み込む方法を解説します。
カスタムモデルには、デフォルトでモデルタイプとモデル名の2つのパラメータが含まれており、サプライヤのyamlファイルで定義する必要はありません。
サプライヤ設定ファイルでvalidate_provider_credentialを実装する必要はありません。Runtimeは、ユーザーが選択したモデルタイプまたはモデル名に基づいて、対応するモデルレイヤのvalidate_credentialsメソッドを自動的に呼び出して検証します。
カスタムモデルプラグインの組み込み
カスタムモデルを組み込むには、以下の手順に従います。
モデルサプライヤファイルの作成
組み込むカスタムモデルのモデルタイプを明確にします。
モデルタイプに応じたコードファイルの作成
モデルのタイプ(
llmやtext_embeddingなど)に応じて、コードファイルを作成します。各モデルタイプが独立した論理構造を持つようにすることで、保守と拡張が容易になります。モデルモジュールに基づいたモデル呼び出しコードの記述
対応するモデルタイプモジュールに、モデルタイプと同名のPythonファイル(例:llm.py)を作成します。ファイル内で、具体的なモデルロジックを実装するクラスを定義します。このクラスは、システムのモデルインターフェース仕様に準拠している必要があります。
プラグインのデバッグ
新たに追加されたサプライヤ機能について、ユニットテストと統合テストを作成し、すべての機能モジュールが期待どおりに動作することを確認します。
1. モデルサプライヤファイルの作成
プラグインプロジェクトの/providerパスに、xinference.yamlファイルを作成します。
Xinferenceは、LLM、Text Embedding、Rerankのモデルタイプをサポートしているため、xinference.yamlファイルにこれらのモデルタイプを含める必要があります。
サンプルコード:
次に、provider_credential_schemaフィールドを定義します。Xinferenceは、text-generation、embeddings、rerankingモデルをサポートしています。サンプルコードを以下に示します。
Xinferenceの各モデルでは、model_nameという名前を定義する必要があります。
Xinferenceモデルでは、ユーザーがモデルのローカルデプロイアドレスを入力する必要があります。プラグイン内では、Xinferenceモデルのローカルデプロイアドレス(server_url)とモデルUIDを入力できる場所を提供する必要があります。サンプルコードを以下に示します。
すべてのパラメータを入力すると、カスタムモデルサプライヤのyaml設定ファイルの作成が完了します。次に、設定ファイルで定義されたモデルに具体的な機能コードファイルを追加する必要があります。
2. モデルコードの記述
Xinferenceモデルサプライヤのモデルタイプには、llm、rerank、speech2text、ttsタイプが含まれています。そのため、/modelsパスに各モデルタイプの独立したグループを作成し、対応する機能コードファイルを作成する必要があります。
以下では、llmタイプを例に、llm.pyコードファイルの作成方法を説明します。コードを作成する際には、Xinference LLMクラスを作成する必要があります。名前はXinferenceAILargeLanguageModelとし、__base.large_language_model.LargeLanguageModel基底クラスを継承し、以下のメソッドを実装します。
LLMの呼び出し
LLM呼び出しの中核となるメソッドです。ストリーミングと同期の両方の戻り値をサポートします。
コードを実装する際には、同期戻り値とストリーミング戻り値で異なる関数を使用する必要があります。
Pythonでは、関数にyieldキーワードが含まれている場合、その関数はジェネレータ関数として認識され、戻り値の型はGeneratorに固定されます。したがって、同期戻り値とストリーミング戻り値をそれぞれ実装する必要があります。例えば、以下のサンプルコードを参照してください。
この例では、パラメータが簡略化されています。実際のコードを記述する際には、上記のパラメータリストを参照してください。
入力トークンの事前計算
モデルがトークンの事前計算インターフェースを提供していない場合は、0を返すことができます。
直接0を返したくない場合は、self._get_num_tokens_by_gpt2(text: str)メソッドを使用してトークンを計算できます。このメソッドはAIModel基底クラスにあり、GPT-2のTokenizerを使用して計算を行います。ただし、あくまで代替手段であり、計算結果には誤差が生じる可能性があることに注意してください。
モデルの認証情報の検証
サプライヤの認証情報の検証と同様に、ここでは個々のモデルを検証します。
モデルパラメータのスキーマ
事前定義されたモデルタイプとは異なり、YAMLファイルにモデルがサポートするパラメータが事前に定義されていないため、モデルパラメータのスキーマを動的に生成する必要があります。
例えば、Xinferenceは
max_tokens、temperature、top_pの3つのモデルパラメータをサポートしています。ただし、サプライヤによっては、モデルごとに異なるパラメータをサポートする場合があります(例:OpenLLM)。例として、サプライヤ
OpenLLMのモデルAはtop_kパラメータをサポートしていますが、モデルBはサポートしていません。この場合、各モデルに対応するパラメータスキーマを動的に生成する必要があります。以下にサンプルコードを示します。呼び出し例外エラーのマッピング
モデルの呼び出し時に例外が発生した場合、Runtimeで指定されたInvokeErrorタイプにマッピングする必要があります。これは、Difyが異なるエラーに対して異なる後続処理を実行できるようにするためです。
Runtime Errors:
さらに詳しいインターフェースメソッドについては、インターフェースドキュメント:Modelを参照してください。
この記事で取り上げた完全なコードファイルについては、GitHubコードリポジトリをご覧ください。
3. プラグインのデバッグ
プラグインの開発が完了したら、次にプラグインが正常に動作するかどうかをテストする必要があります。詳細については、以下を参照してください。
プラグインのデバッグ方法4. プラグインの公開
プラグインをDify マーケットプレイスに公開する場合は、以下を参照してください。
Difyマーケットプレイスへの公開さらに詳しく
クイックスタート:
プラグインインターフェースドキュメント:
Last updated